Mercurial > sqlpython
changeset 440:e1a962dd7139
refactored connections
author | catherine@dellzilla |
---|---|
date | Mon, 01 Feb 2010 15:56:15 -0500 |
parents | 0a2474b76db6 |
children | da332a670378 |
files | sqlpython/connections.py sqlpython/sqlpython.py |
diffstat | 2 files changed, 192 insertions(+), 156 deletions(-) [+] |
line wrap: on
line diff
--- a/sqlpython/connections.py Mon Feb 01 10:15:28 2010 -0500 +++ b/sqlpython/connections.py Mon Feb 01 15:56:15 2010 -0500 @@ -5,6 +5,23 @@ import time import threading import pickle +import optparse +import doctest + +try: + import cx_Oracle +except ImportError: + pass + +try: + import psycopg2 +except ImportError: + pass + +try: + import MySQLdb +except ImportError: + pass class ObjectDescriptor(object): def __init__(self, name, dbobj): @@ -31,8 +48,164 @@ class GeraldPlaceholder(object): current = False - complete = False + complete = False + +class OptionTestDummy(object): + mysql = None + postgres = None + username = None + password = None + hostname = None + port = None + database = None + mode = 0 + def __init__(self, *args, **kwargs): + self.__dict__.update(kwargs) +class ConnectionData(object): + username = None + password = None + hostname = None + port = None + database = None + mode = 0 + connection_uri_parser = re.compile('(postgres|oracle|mysql|sqlite|mssql):/(.*$)', re.IGNORECASE) + connection_parser = re.compile('((?P<database>\S+)(\s+(?P<username>\S+))?)?') + def __init__(self, arg, opts, default_rdbms = 'oracle'): + ''' + >>> opts = OptionTestDummy(postgres=True, password='password') + >>> ConnectionData('thedatabase theuser', opts).uri() + 'postgres://theuser:password@localhost:5432/thedatabase' + >>> opts = OptionTestDummy(password='password') + >>> ConnectionData('oracle://user:password@db', opts).uri() + 'oracle://user:password@db' + >>> ConnectionData('user/password@db', opts).uri() + 'oracle://user:password@db' + >>> ConnectionData('user/password@db as sysdba', opts).uri() + 'oracle://user:password@db?mode=2' + >>> ConnectionData('user/password@thehost/db', opts).uri() + 'oracle://user:password@thehost:1521/db' + >>> opts = OptionTestDummy(postgres=True, hostname='thehost', password='password') + >>> ConnectionData('thedatabase theuser', opts).uri() + 'postgres://theuser:password@thehost:5432/thedatabase' + >>> opts = OptionTestDummy(mysql=True, password='password') + >>> ConnectionData('thedatabase theuser', opts).uri() + 'mysql://theuser:password@localhost:3306/thedatabase' + ''' + self.arg = arg + self.opts = opts + self.default_rdbms = default_rdbms + self.determine_rdbms() + if not self.parse_connect_uri(arg): + self.set_defaults() + connectargs = self.connection_parser.search(self.arg) + if connectargs: + for param in ('username', 'password', 'database', 'port', 'hostname', 'mode'): + if hasattr(opts, param) and getattr(opts, param): + setattr(self, param, getattr(opts, param)) + else: + try: + if connectargs.group(param): + setattr(self, param, connectargs.group(param)) + except IndexError: + pass + self.set_corrections() + if not self.password: + self.password = getpass.getpass() + def parse_connect_uri(self, uri): + results = self.connection_uri_parser.search(uri) + if results: + (self.username, self.password, self.hostname, self.port, self.database + ) = gerald.utilities.dburi.Connection().parse_uri(results.group(2)) + self.set_class_from_rdbms_name(results.group(1)) + self.port = self.port or self.default_port + return True + else: + return False + def set_class_from_rdbms_name(self, rdbms_name): + for cls in (OracleConnectionData, PostgresConnectionData, MySQLConnectionData): + if cls.rdbms == rdbms_name: + self.__class__ = cls + def uri(self): + return '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password, + self.hostname, self.port, self.database) + def gerald_uri(self): + return self.uri().split('?mode=')[0] + def determine_rdbms(self): + if self.opts.mysql: + self.__class__ = MySQLConnectionData + elif self.opts.postgres: + self.__class__ = PostgresConnectionData + else: + self.set_class_from_rdbms_name(self.default_rdbms) + def set_defaults(self): + self.port = self.default_port + def set_corrections(self): + pass + +class MySQLConnectionData(ConnectionData): + rdbms = 'mysql' + default_port = 3306 + def set_defaults(self): + self.port = self.default_port + self.hostname = 'localhost' + def connection(self): + return MySQLdb.connect(host = self.hostname, user = self.username, + passwd = self.password, db = self.database, + port = self.port, sql_mode = 'ANSI') + +class PostgresConnectionData(ConnectionData): + rdbms = 'postgres' + default_port = 5432 + def set_defaults(self): + self.port = os.getenv('PGPORT') or self.default_port + self.database = os.getenv('ORACLE_SID') + self.hostname = os.getenv('PGHOST') or 'localhost' + def connection(self): + return psycopg2.connect(host = self.hostname, user = self.username, + password = self.password, database = self.database, + port = self.port) + +class OracleConnectionData(ConnectionData): + rdbms = 'oracle' + default_port = 1521 + connection_parser = re.compile('(?P<username>[^/\s@]*)(/(?P<password>[^/\s@]*))?(@((?P<hostname>[^/\s:]*)(:(?P<port>\d{1,4}))?/)?(?P<database>[^/\s:]*))?(\s+as\s+(?P<mode>sys(dba|oper)))?', + re.IGNORECASE) + def uri(self): + if self.hostname: + uri = '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password, + self.hostname, self.port, self.database) + else: + uri = '%s://%s:%s@%s' % (self.rdbms, self.username, self.password, self.database) + if self.mode: + uri = '%s?mode=%d' % (uri, self.mode) + return uri + def set_defaults(self): + self.port = 1521 + self.database = os.getenv('ORACLE_SID') + def set_corrections(self): + if self.mode: + self.mode = getattr(cx_Oracle, self.mode.upper()) + if self.hostname: + self.dsn = cx_Oracle.makedsn(self.hostname, self.port, self.database) + else: + self.dsn = self.database + def parse_connect_uri(self, uri): + if ConnectionData.parse_connect_uri(self, uri): + if not self.database: + self.database = self.hostname + self.hostname = None + self.port = self.default_port + return True + return False + def connection(self): + return cx_Oracle.connect(user = self.username, password = self.password, + dsn = self.dsn, mode = self.mode) + +gerald_classes = {'oracle': gerald.oracle_schema.User, + 'postgres': gerald.PostgresSchema, + 'mysql': gerald.MySQLSchema } + class DatabaseInstance(object): import_failure = None username = None @@ -43,55 +216,21 @@ connection_uri_parser = re.compile('(postgres|oracle|mysql|sqlite|mssql):/(.*$)', re.IGNORECASE) def __init__(self, arg, opts, default_rdbms = 'oracle'): - opts.username = opts.username or opts.user - self.default_rdbms = default_rdbms - if not self.parse_connect_uri(arg): - self.parse_connect_arg(arg, opts) - self.connection = self.new_connection() + #opts.username = opts.username or opts.user + self.conn_data = ConnectionData(arg, opts, default_rdbms) + for v in ('username', 'database', 'rdbms'): + setattr(self, v, getattr(self.conn_data, v)) + self.connection = self.conn_data.connection() self.gerald = GeraldPlaceholder() self.discover_metadata() def discover_metadata(self): self.metadata_discovery_thread = MetadataDiscoveryThread(self) self.metadata_discovery_thread.start() - - def parse_connect_uri(self, uri): - results = self.connection_uri_parser.search(uri) - if results: - (self.username, self.password, self.host, self.port, self.db_name - ) = gerald.utilities.dburi.Connection().parse_uri(results.group(2)) - self.__class__ = rdbms_types.get(results.group(1)) - self.uri = uri - self.port = self.port or self.default_port - return True - else: - return False - - def parse_connect_arg(self, arg, opts): - self.host = opts.hostname - self.oracle_connect_mode = 0 - if opts.postgres: - self.__class__ = PostgresDatabaseInstance - elif opts.mysql: - self.__class__ = MySQLDatabaseInstance - elif opts.oracle: - self.__class__ = OracleDatabaseInstance - else: - self.__class__ = rdbms_types.get(self.default_rdbms) - self.assign_args(arg, opts) - self.db_name = opts.database or self.db_name - self.port = self.port or self.default_port - self.password = self.password or opts.password or getpass.getpass('Password: ') - self.uri = self.uri or self.calculated_uri() - def calculated_uri(self): - return '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password, - self.host, self.port, self.db_name) - def gerald_uri(self): - return self.uri.split('?mode=')[0] def set_instance_number(self, instance_number): self.instance_number = instance_number - self.prompt = "%d:%s@%s> " % (self.instance_number, self.username, self.db_name) + self.prompt = "%d:%s@%s> " % (self.instance_number, self.username, self.database) def pickle(self): try: os.mkdir(self.pickledir) @@ -102,13 +241,13 @@ picklefile.close() def picklefile(self): return os.path.join(self.pickledir, ('%s.%s.%s.%s.pickle' % - (self.rdbms, self.username, self.host, self.db_name)).lower()) + (self.rdbms, self.username, self.conn_data.hostname, self.database)).lower()) def retreive_pickled_gerald(self): picklefile = open(self.picklefile()) schema = pickle.load(picklefile) picklefile.close() - newgerald = rdbms_types[self.rdbms].gerald_class(self.username, None) - newgerald.connect(self.gerald_uri()) + newgerald = gerald_classes[self.rdbms](self.username, None) + newgerald.connect(self.conn_data.gerald_uri()) newgerald.schema = schema newgerald.current = False newgerald.complete = True @@ -116,97 +255,6 @@ for (name, obj) in newgerald.schema.items(): newgerald.descriptions[name] = ObjectDescriptor(name, obj) self.gerald = newgerald - -class OpenSourceDatabaseInstance(DatabaseInstance): - def assign_args(self, arg, opts): - self.assign_details(arg, opts) - self.username = self.username or os.environ['USER'] - self.db_name = self.db_name or self.username - self.host = opts.hostname or self.host or 'localhost' - -class ImportFailure(DatabaseInstance): - def fail(self, *arg, **kwargs): - raise ImportError, 'Python DB-API2 module (MySQLdb/psycopg2/cx_Oracle) was not successfully imported' - assign_args = fail - new_connection = fail - -try: - import psycopg2 - class PostgresDatabaseInstance(OpenSourceDatabaseInstance): - gerald_class = gerald.PostgresSchema - rdbms = 'postgres' - default_port = 5432 - def assign_details(self, arg, opts): - self.port = opts.port or os.getenv('PGPORT') or self.default_port - self.host = self.host or os.getenv('PGHOST') - args = arg.split() - if len(args) > 1: - self.username = args[1] - if len(args) > 0: - self.db_name = args[0] - def new_connection(self): - return psycopg2.connect(host = self.host, user = self.username, - password = self.password, database = self.db_name, - port = self.port) -except ImportError: - PostgresDatabaseInstance = ImportFailure - -try: - import MySQLdb - class MySQLDatabaseInstance(OpenSourceDatabaseInstance): - gerald_class = gerald.MySQLSchema - rdbms = 'mysql' - default_port = 3306 - def assign_details(self, arg, opts): - self.db_name = arg - def new_connection(self): - return MySQLdb.connect(host = self.host, user = self.username, - passwd = self.password, db = self.db_name, - port = self.port, sql_mode = 'ANSI') -except ImportError: - MySQLDatabaseInstance = ImportFailure -try: - import cx_Oracle - - class OracleDatabaseInstance(DatabaseInstance): - gerald_class = gerald.oracle_schema.User - rdbms = 'oracle' - connection_parser = re.compile('(?P<username>[^/\s@]*)(/(?P<password>[^/\s@]*))?(@((?P<host>[^/\s:]*)(:(?P<port>\d{1,4}))?/)?(?P<db_name>[^/\s:]*))?(\s+as\s+(?P<mode>sys(dba|oper)))?', - re.IGNORECASE) - connection_modes = {'SYSDBA': cx_Oracle.SYSDBA, 'SYSOPER': cx_Oracle.SYSOPER} - oracle_connect_mode = 0 - default_port = 1521 - def assign_args(self, arg, opts): - connectargs = self.connection_parser.search(arg) - self.username = connectargs.group('username') - self.password = connectargs.group('password') - self.db_name = connectargs.group('db_name') or os.getenv('ORACLE_SID') - self.port = connectargs.group('port') or self.default_port - self.host = connectargs.group('host') - if self.host: - self.dsn = cx_Oracle.makedsn(self.host, self.port, self.db_name) - else: - self.dsn = self.db_name - if connectargs.group('mode'): - self.oracle_connect_mode = self.connection_modes.get(connectargs.group('mode').upper()) - def calculated_uri(self): - if self.host: - result = DatabaseInstance.calculated_uri(self) - else: - result = '%s://%s:%s@%s' % (self.rdbms, self.username, self.password, - self.db_name) - if self.oracle_connect_mode: - result = "%s?mode=%d" % (result, self.oracle_connect_mode) - return result - def new_connection(self): - return cx_Oracle.connect(user = self.username, - password = self.password, - dsn = self.dsn, - mode = self.oracle_connect_mode) - - -except ImportError: - OracleDatabaseInstance = ImportFailure class MetadataDiscoveryThread(threading.Thread): def __init__(self, db_instance): @@ -219,7 +267,7 @@ except IOError: pass self.db_instance.gerald.current = False - newgerald = self.db_instance.gerald_class(self.db_instance.username, self.db_instance.gerald_uri()) + newgerald = gerald_classes[self.db_instance.rdbms](self.db_instance.username, self.db_instance.conn_data.gerald_uri()) newgerald.descriptions = {} for (name, obj) in newgerald.schema.items(): newgerald.descriptions[name] = ObjectDescriptor(name, obj) @@ -227,7 +275,6 @@ newgerald.complete = True self.db_instance.gerald = newgerald self.db_instance.pickle() - -rdbms_types = {'oracle': OracleDatabaseInstance, 'mysql': MySQLDatabaseInstance, 'postgres': PostgresDatabaseInstance} - - \ No newline at end of file + +if __name__ == '__main__': + doctest.testmod() \ No newline at end of file
--- a/sqlpython/sqlpython.py Mon Feb 01 10:15:28 2010 -0500 +++ b/sqlpython/sqlpython.py Mon Feb 01 15:56:15 2010 -0500 @@ -95,17 +95,6 @@ self.no_instance() legal_sql_word = pyparsing.Word(pyparsing.alphanums + '_$#') - legal_hostname = pyparsing.Word(pyparsing.alphanums + '_-.')('host') + pyparsing.Optional( - ':' + pyparsing.Word(pyparsing.nums)('port')) - oracle_connect_parser = legal_sql_word('username') + ( - pyparsing.Optional('/' + pyparsing.CharsNotIn('@')("password")) + - pyparsing.Optional('@' + pyparsing.Optional(legal_hostname + '/') + - legal_sql_word('db_name')) + - pyparsing.Optional(pyparsing.CaselessKeyword('as') + - (pyparsing.CaselessKeyword('sysoper') ^ - pyparsing.CaselessKeyword('sysdba'))('mode'))) - postgresql_connect_parser = pyparsing.Optional(legal_sql_word('db_name') + - pyparsing.Optional(legal_sql_word('username'))) def successfully_connect_to_number(self, arg): try: @@ -131,17 +120,17 @@ cmd2.make_option('--oracle', action='store_true', help='Connect to an Oracle database'), cmd2.make_option('--mysql', action='store_true', help='Connect to a MySQL database'), cmd2.make_option('-H', '--hostname', type='string', - help='Machine where database is hosted (postgresql only)'), + help='Machine where database is hosted'), cmd2.make_option('-p', '--port', type='int', - help='Port to connect to (postgresql only)'), + help='Port to connect to'), cmd2.make_option('--password', type='string', - help='Password (mysql only)'), + help='Password'), cmd2.make_option('-d', '--database', type='string', help='Database name to connect to'), cmd2.make_option('-U', '--username', type='string', help='Database user name to connect as'), - cmd2.make_option('-u', '--user', type='string', - help='Database user name to connect as') +# cmd2.make_option('-u', '--user', type='string', +# help='Database user name to connect as') ]) def do_connect(self, arg, opts):