# Source: Mercurial repository "sqlpython"
# File: sqlpython/connections.py @ 510:c8de86e7cd06
# Changeset summary: lambdas for calling code objects
# Author: catherine.devlin@gmail.com
# Date: Fri, 24 Sep 2010 19:00:37 -0400
# Parents: 85495d4d6c73
# Children: eccf817b0fbc
"""Database connection handling for sqlpython.

Defines ``DatabaseInstance`` and its per-RDBMS subclasses (Oracle, PostgreSQL,
MySQL), which parse a connect string or URI, open a DB-API connection and run
the data-dictionary queries used to list objects, columns and source code.
"""
import re
import os
import getpass
import gerald
import time
import optparse
import doctest
import pyparsing

# Maps an RDBMS name to the gerald schema class for every database adapter
# that could actually be imported on this machine.
gerald_classes = {}

try:
    import cx_Oracle
    gerald_classes['oracle'] = gerald.oracle_schema.User
except ImportError:
    pass

try:
    import psycopg2
    gerald_classes['postgres'] = gerald.PostgresSchema
except ImportError:
    pass

try:
    import MySQLdb
    gerald_classes['mysql'] = gerald.MySQLSchema
except ImportError:
    pass

#if not gerald_classes:
#    raise ImportError('No Python database adapters installed!')

# Oracle-style connect string: user[/password][@[host[:port]/]database][ as sysdba|sysoper]
# Compiled once and shared by DatabaseInstance and OracleInstance.
_ORACLE_CONNECTION_PARSER = re.compile(
    r'(?P<username>[^/\s@]*)(/(?P<password>[^/\s@]*))?(@((?P<hostname>[^/\s:]*)(:(?P<port>\d{1,4}))?/)?(?P<database>[^/\s:]*))?(\s+as\s+(?P<mode>sys(dba|oper)))?',
    re.IGNORECASE)


class ObjectDescriptor(object):
    """Name, type and path information about one database object."""

    def __init__(self, name, dbobj):
        """Derive type, owner and path attributes from *name* and *dbobj*.

        name  -- object name, optionally qualified as ``owner.name``
        dbobj -- the underlying schema object; its ``type`` attribute is used
                 when present, otherwise its Python class name
        """
        self.fullname = name
        self.dbobj = dbobj
        if hasattr(self.dbobj, 'type'):
            self.type = self.dbobj.type.lower()
        else:
            # e.g. "<class 'pkg.mod.Table'>" -> "table"
            self.type = str(type(self.dbobj)).split('.')[-1].lower().strip("'>")
        self.path = '%s/%s' % (self.type, self.fullname)
        if '.' in self.fullname:
            # Split on the first dot only, so a name containing further dots
            # does not raise ValueError on unpacking.
            (self.owner, self.unqualified_name) = self.fullname.split('.', 1)
            self.owner = self.owner.lower()
        else:
            (self.owner, self.unqualified_name) = (None, self.fullname)
        self.unqualified_path = '%s/%s' % (self.type, self.unqualified_name)

    def match_pattern(self, pattern, specific_owner=None):
        """Return a truthy value if this object matches the regex *pattern*.

        An empty pattern matches whenever the owner is acceptable.  A pattern
        containing an escaped dot (``\\.``) is matched against the qualified
        name and path, ignoring *specific_owner*; any other pattern is matched
        against the type, unqualified name and unqualified path, and the owner
        must also be acceptable.
        """
        right_owner = ((not self.owner) or (not specific_owner)
                       or (self.owner == specific_owner.lower()))
        if not pattern:
            return right_owner
        compiled = re.compile(pattern, re.IGNORECASE)
        if r'\.' in pattern:
            return compiled.match(self.fullname) or compiled.match(self.path)
        return right_owner and (compiled.match(self.type)
                                or compiled.match(self.type + r'/')
                                or compiled.match(self.unqualified_name)
                                or compiled.match(self.unqualified_path))


class DBOjbect(object):
    """Plain record holding the schema, type and name of a database object."""

    def __init__(self, schema, object_type, name):
        self.schema = schema
        self.type = object_type
        self.name = name


# Correctly spelled alias; the original (misspelt) name is kept for callers.
DBObject = DBOjbect


class OptionTestDummy(object):
    """Stand-in for a parsed options object, used by the usage examples."""
    mysql = None
    postgres = None
    username = None
    password = None
    hostname = None
    port = None
    database = None
    mode = 0

    def __init__(self, *args, **kwargs):
        """Set every keyword argument as an attribute; positional args are ignored."""
        self.__dict__.update(kwargs)


class DatabaseInstance(object):
    """A connection to one database.

    Constructing an instance re-assigns ``self.__class__`` to the subclass for
    the detected RDBMS, fills in connection parameters and connects.
    """
    username = None
    password = None
    hostname = None
    port = None
    database = None
    mode = 0
    connection_uri_parser = re.compile('(?P<rdbms>postgres|oracle|mysql|sqlite|mssql)://?(?P<connect_string>.*$)', re.IGNORECASE)
    oracle_style_connection_parser = _ORACLE_CONNECTION_PARSER
    connection_parser = re.compile(r'((?P<database>\S+)(\s+(?P<username>\S+))?)?')

    def __init__(self, arg, opts, default_rdbms = 'oracle'):
        """Parse *arg* and *opts* into connection parameters, then connect.

        arg           -- connect string: a ``rdbms://...`` URI, an Oracle-style
                         ``user/password@db`` string, or ``database [username]``
        opts          -- options object (see the module-level ``parser``);
                         truthy option values override values parsed from *arg*
                         when *arg* is not a URI
        default_rdbms -- RDBMS name used when neither *arg* nor *opts* selects one

        Prompts for a password with getpass when none was supplied.
        """
        # Usage examples.  Deliberately kept out of the docstring so that
        # doctest.testmod() does not run them: they need live database
        # connections.
        #
        #   opts = OptionTestDummy(postgres=True, password='password')
        #   DatabaseInstance('thedatabase theuser', opts).uri()
        #       -> 'postgres://theuser:password@localhost:5432/thedatabase'
        #   opts = OptionTestDummy(password='password')
        #   DatabaseInstance('oracle://user:password@db', opts).uri()
        #       -> 'oracle://user:password@db'
        #   DatabaseInstance('user/password@db', opts).uri()
        #       -> 'oracle://user:password@db'
        #   DatabaseInstance('user/password@db as sysdba', opts).uri()
        #       -> 'oracle://user:password@db?mode=2'
        #   DatabaseInstance('user/password@thehost/db', opts).uri()
        #       -> 'oracle://user:password@thehost:1521/db'
        #   opts = OptionTestDummy(postgres=True, hostname='thehost', password='password')
        #   DatabaseInstance('thedatabase theuser', opts).uri()
        #       -> 'postgres://theuser:password@thehost:5432/thedatabase'
        #   opts = OptionTestDummy(mysql=True, password='password')
        #   DatabaseInstance('thedatabase theuser', opts).uri()
        #       -> 'mysql://theuser:password@localhost:3306/thedatabase'
        #   opts = OptionTestDummy(mysql=True, password='password')
        #   DatabaseInstance('thedatabase', opts).uri()
        #       -> 'mysql://cat:password@localhost:3306/thedatabase'
        self.arg = arg
        self.opts = opts
        self.default_rdbms = default_rdbms
        self.determine_rdbms()  # may be altered later as connect string is parsed
        if not self.parse_connect_uri(arg):
            self.set_defaults()
            connectargs = self.connection_parser.search(self.arg)
            # The 'database' group is None when the string names no database
            # (e.g. Oracle "user/password"); guard before testing for '@'.
            if '@' in (connectargs.group('database') or ''):
                connectargs = self.oracle_style_connection_parser.search(self.arg)
            if connectargs:
                for param in ('username', 'password', 'database', 'port', 'hostname', 'mode'):
                    option_value = getattr(opts, param, None)
                    if option_value:
                        setattr(self, param, option_value)
                        continue
                    try:
                        parsed_value = connectargs.group(param)
                    except IndexError:
                        # This parser has no group of that name.
                        continue
                    if parsed_value:
                        setattr(self, param, parsed_value)
        self.set_corrections()
        if not self.password:
            self.password = getpass.getpass()
        self.connect()

    def parse_connect_uri(self, uri):
        """Fill connection parameters from a ``rdbms://...`` URI.

        Returns True if *uri* matched ``connection_uri_parser`` (and the
        instance was updated), False otherwise.
        """
        results = self.connection_uri_parser.search(uri)
        if not results:
            return False
        self.set_class_from_rdbms_name(results.group('rdbms'))
        r = gerald.utilities.dburi.Connection().parse_uri(results.group('connect_string'))
        self.username = r.get('user') or self.username
        self.password = r.get('password') or self.password
        self.hostname = r.get('host') or self.hostname
        # Honour a port given in the URI when the parser reports one; it was
        # previously ignored.  Numeric strings are converted to int.
        uri_port = r.get('port')
        if uri_port:
            try:
                uri_port = int(uri_port)
            except (TypeError, ValueError):
                pass
        self.port = uri_port or self.port or self.default_port
        self.database = r.get('db_name')
        return True

    def set_class_from_rdbms_name(self, rdbms_name):
        """Switch this instance to the subclass whose ``rdbms`` is *rdbms_name*.

        The comparison is case-insensitive because ``connection_uri_parser``
        accepts any case.  Names with no matching subclass leave the class
        unchanged.
        """
        wanted = (rdbms_name or '').lower()
        for cls in (OracleInstance, PostgresInstance, MySQLInstance):
            if cls.rdbms == wanted:
                self.__class__ = cls

    def uri(self):
        """Return the connection as ``rdbms://user:password@host:port/database``."""
        return '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password,
                                        self.hostname, self.port, self.database)

    def determine_rdbms(self):
        """Pick the subclass from the options, the URI scheme or the default."""
        if self.opts.mysql or self.arg.startswith('mysql://'):
            self.__class__ = MySQLInstance
        elif self.opts.postgres or self.arg.startswith(('postgres://', 'postgresql://')):
            self.__class__ = PostgresInstance
        else:
            self.set_class_from_rdbms_name(self.default_rdbms)

    def set_defaults(self):
        """Set default connection parameters; subclasses extend this."""
        self.port = self.default_port

    def set_corrections(self):
        """Hook for subclass-specific adjustments made after parsing."""
        pass

    def set_instance_number(self, instance_number):
        """Record the instance number and build the ``N:user@database> `` prompt."""
        self.instance_number = instance_number
        self.prompt = "%d:%s@%s> " % (self.instance_number, self.username, self.database)

    sqlname = pyparsing.Word(pyparsing.alphas + '$_#%*', pyparsing.alphanums + '$_#%*')
    # Accepts either "owner/type/name" or "type/owner.name"; every part is optional.
    ls_parser = ( (pyparsing.Optional(sqlname("owner") + "/") +
                   pyparsing.Optional(sqlname("type") + "/") +
                   pyparsing.Optional(sqlname("name")) +
                   pyparsing.stringEnd )
                | ( pyparsing.Optional(sqlname("type") + "/") +
                    pyparsing.Optional(sqlname("owner") + ".") +
                    pyparsing.Optional(sqlname("name")) +
                    pyparsing.stringEnd ))
    # Optional leading object type followed by the rest of the identifier.
    identifier_regex = re.compile(
        r'((?P<object_type>DATABASE LINK|DIRECTORY|FUNCTION|INDEX|JOB|MATERIALIZED VIEW|PACKAGE|PROCEDURE|SEQUENCE|SYNONYM|TABLE|TRIGGER|TYPE|VIEW|BASE TABLE)($|[\\/.\s])+)?(?P<remainder>.*)',
        re.IGNORECASE)

    def comparison_operator(self, target):
        """Return 'LIKE' if *target* contains an SQL wildcard (% or _), else '='."""
        if ('%' in target) or ('_' in target):
            return 'LIKE'
        return '='

    def sql_format_wildcards(self, target):
        """Translate shell-style wildcards (* and ?) into SQL ones (% and _)."""
        return target.replace('*', '%').replace('?', '_')

    def comparitor(self, target):
        """Same as comparison_operator; kept under its original name for callers."""
        return self.comparison_operator(target)

    def objects(self, target, opts):
        """Run ``all_object_qry`` for *target* and return the open cursor.

        target -- ``[object type] [owner.]name``; * and ? act as wildcards
        opts   -- options object; a truthy ``reverse`` sorts descending and a
                  truthy ``all`` sets the query's %(all)s condition to '1 = 1'
        """
        match = self.identifier_regex.search(target)
        object_type = self.name_case(match.group('object_type') or '%')
        remainder = self.name_case(self.sql_format_wildcards(match.group('remainder')))
        # Pad with wildcards so names[0] and names[1] always exist.
        names = [n.strip() or '%' for n in remainder.split('.')] + ['%', '%']
        replacements = {'name1_comparitor': self.comparitor(names[0]),
                        'name2_comparitor': self.comparitor(names[1]),
                        'object_type_comparitor': self.comparitor(object_type),
                        'sort': 'ASC',
                        'all': '1 = 0'}
        if getattr(opts, 'reverse', False):
            replacements['sort'] = 'DESC'
        if getattr(opts, 'all', False):
            replacements['all'] = '1 = 1'
        qry = self.all_object_qry % replacements
        binds = {'schema': self.name_case(self.username),
                 'object_type': object_type,
                 'name1': names[0],
                 'name2': names[1]}
        curs = self.connection.cursor()
        curs.execute(qry, binds)
        return curs

    def columns(self, target, opts):
        """Run ``column_qry`` for column name *target* and return the open cursor.

        With ``opts.all`` set, every owner is searched; otherwise only the
        connected user's.
        """
        target = self.sql_format_wildcards(target)
        if opts.all:
            owner = '%'
        else:
            owner = self.username
        qry = self.column_qry % (self.comparison_operator(owner),
                                 self.bindSyntax('owner'),
                                 self.comparison_operator(target),
                                 self.bindSyntax('colname'))
        binds = (('owner', owner), ('colname', target))
        curs = self.connection.cursor()
        curs.execute(qry, self.bindVariables(binds))
        return curs

    def source(self, target, opts):
        """Run ``source_qry`` looking for *target* and return the open cursor.

        With ``opts.all`` set, every owner is searched; otherwise only the
        connected user's.
        """
        if opts.all:
            owner = '%'
        else:
            owner = self.username
        qry = self.source_qry % (self.comparison_operator(owner),
                                 self.bindSyntax('owner'),
                                 self.bindSyntax('target'))
        binds = (('owner', owner), ('target', target))
        curs = self.connection.cursor()
        curs.execute(qry, self.bindVariables(binds))
        return curs

    # Object type -> factory called as factory(name, cursor, owner).
    gerald_types = {'TABLE': gerald.oracle_schema.Table,
                    'VIEW': gerald.oracle_schema.View}

    def object_metadata(self, owner, object_type, name):
        """Build the gerald object describing *owner*.*name*.

        Raises NotImplementedError when *object_type* has no entry in
        ``gerald_types`` for this RDBMS.
        """
        if object_type in self.gerald_types:
            return self.gerald_types[object_type](name, self.connection.cursor(), owner)
        # Call form of raise: valid on both Python 2 and Python 3.
        raise NotImplementedError('%s not implemented for this RDBMS' % object_type)


parser = optparse.OptionParser()
parser.add_option('--postgres', action='store_true', help='Connect to postgreSQL: `connect --postgres [DBNAME [USERNAME]]`')
parser.add_option('--oracle', action='store_true', help='Connect to an Oracle database')
parser.add_option('--mysql', action='store_true', help='Connect to a MySQL database')
parser.add_option('-H', '--hostname', type='string',
                  help='Machine where database is hosted')
parser.add_option('-p', '--port', type='int',
                  help='Port to connect to')
parser.add_option('--password', type='string',
                  help='Password')
parser.add_option('-d', '--database', type='string',
                  help='Database name to connect to')
parser.add_option('-U', '--username', type='string',
                  help='Database user name to connect as')


def connect(connstr):
    """Parse *connstr* with the module-level option parser and print the result."""
    (options, args) = parser.parse_args(connstr)
    # Single-argument print(...) behaves the same on Python 2 and Python 3.
    print(options)
    print(args)


class MySQLInstance(DatabaseInstance):
    """DatabaseInstance specialised for MySQL (via MySQLdb)."""
    rdbms = 'mysql'
    default_port = 3306

    def set_defaults(self):
        """Default to localhost, with $USER as both user name and database."""
        self.port = self.default_port
        self.hostname = 'localhost'
        self.username = os.getenv('USER')
        self.database = os.getenv('USER')

    def connect(self):
        """Open the MySQLdb connection and store it in ``self.connection``."""
        self.connection = MySQLdb.connect(host = self.hostname, user = self.username,
                                          passwd = self.password, db = self.database,
                                          port = self.port, sql_mode = 'ANSI')

    def bindSyntax(self, varname):
        """Return the positional placeholder; *varname* is unused."""
        return '%s'

    def bindVariables(self, binds):
        """Puts a tuple of (name, value) pairs into the bind format desired by MySQL"""
        # A real tuple rather than a one-shot generator, so the result can be
        # inspected or re-used safely.
        return tuple(i[1] for i in binds)

    # NOTE(review): these queries reference all_tab_columns / all_objects /
    # all_source, which are Oracle dictionary view names -- confirm they are
    # intended for MySQL.  Columns are qualified as in OracleInstance.column_qry.
    column_qry = """SELECT atc.owner, ao.object_type, atc.table_name, atc.column_name
                    FROM   all_tab_columns atc
                    JOIN   all_objects ao ON (atc.table_name = ao.object_name AND atc.owner = ao.owner)
                    WHERE  atc.owner %s %s
                    AND    atc.column_name %s %s """
    source_qry = """SELECT owner, type, name, line, text
                    FROM   all_source
                    WHERE  owner %s %s
                    AND    UPPER(text) LIKE %s"""


class PostgresInstance(DatabaseInstance):
    """DatabaseInstance specialised for PostgreSQL (via psycopg2)."""
    rdbms = 'postgres'
    default_port = 5432

    def name_case(self, s):
        """Return *s* lower-cased."""
        return s.lower()

    def set_defaults(self):
        """Take defaults from PGPORT, PGDATABASE, PGHOST and USER."""
        self.port = os.getenv('PGPORT') or self.default_port
        # PGDATABASE is the PostgreSQL variable; ORACLE_SID is retained as a
        # fallback because it was the only source consulted before.
        self.database = os.getenv('PGDATABASE') or os.getenv('ORACLE_SID')
        self.hostname = os.getenv('PGHOST') or 'localhost'
        self.username = os.getenv('USER')

    def connect(self):
        """Open the psycopg2 connection and store it in ``self.connection``."""
        self.connection = psycopg2.connect(host = self.hostname, user = self.username,
                                           password = self.password, database = self.database,
                                           port = self.port)

    def bindSyntax(self, varname):
        """Return the named placeholder ``%(varname)s``."""
        return '%%(%s)s' % varname

    def bindVariables(self, binds):
        """Puts a tuple of (name, value) pairs into the bind format desired by psycopg2"""
        return dict((b[0], b[1].lower()) for b in binds)

    all_object_qry = """SELECT * FROM (
                          SELECT ns.nspname AS schema,
                                 CASE c.relkind
                                   WHEN 'r' THEN 'table'
                                   WHEN 'i' THEN 'index'
                                   WHEN 'S' THEN 'sequence'
                                   WHEN 'v' THEN 'view'
                                   WHEN 'c' THEN 'composite type'
                                   WHEN 't' THEN 'toast table'
                                 END AS object_type,
                                 c.relname AS object_name,
                                 NULL AS synonym_name
                          FROM   pg_namespace ns
                          JOIN   pg_class c ON (ns.oid = c.relnamespace)
                          WHERE  ( ( ns.nspname = %%(schema)s
                                     OR position(ns.nspname in (SELECT setting FROM pg_settings WHERE name = 'search_path')) > 0
                                     OR %(all)s )
                                   AND c.relname %(name1_comparitor)s %%(name1)s )
                          OR     ( ns.nspname %(name1_comparitor)s %%(name1)s
                                   AND c.relname %(name2_comparitor)s %%(name2)s )
                        ) subq
                        WHERE  object_type %(object_type_comparitor)s %%(object_type)s
                        ORDER BY object_type, schema, object_name %(sort)s"""
    column_qry = """SELECT c.table_schema, t.table_type, c.table_name, c.column_name
                    FROM   information_schema.columns c
                    JOIN   information_schema.tables t ON (c.table_schema = t.table_schema AND c.table_name = t.table_name)
                    WHERE  ( (c.table_schema %s %s) OR (c.table_schema = 'public'))
                    AND    c.column_name %s %s """
    source_qry = """SELECT owner, type, name, line, text
                    FROM   all_source
                    WHERE  owner %s %s
                    AND    UPPER(text) LIKE %s"""
    gerald_types = {'table': gerald.postgres_schema.Table,
                    'view': gerald.postgres_schema.View,
                    'trigger': gerald.postgres_schema.Trigger}


class OracleInstance(DatabaseInstance):
    """DatabaseInstance specialised for Oracle (via cx_Oracle)."""
    rdbms = 'oracle'
    default_port = 1521
    connection_parser = _ORACLE_CONNECTION_PARSER

    def name_case(self, s):
        """Return *s* upper-cased."""
        return s.upper()

    def uri(self):
        """Return the connection URI.

        Host and port are included only when a hostname is set, and
        ``?mode=N`` is appended when a connection mode is set.
        """
        if self.hostname:
            uri = '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password,
                                            self.hostname, self.port, self.database)
        else:
            uri = '%s://%s:%s@%s' % (self.rdbms, self.username, self.password, self.database)
        if self.mode:
            uri = '%s?mode=%d' % (uri, self.mode)
        return uri

    def set_defaults(self):
        """Default to the standard port and the database named by ORACLE_SID."""
        self.port = self.default_port
        self.database = os.getenv('ORACLE_SID')

    def set_corrections(self):
        """Resolve the connection mode and build the DSN."""
        # A mode parsed from the connect string is a name such as 'sysdba';
        # look up the cx_Oracle attribute of that name.  A mode that is not a
        # string (e.g. already numeric) is left untouched.
        if self.mode and hasattr(self.mode, 'upper'):
            self.mode = getattr(cx_Oracle, self.mode.upper())
        if self.hostname:
            self.dsn = cx_Oracle.makedsn(self.hostname, self.port, self.database)
        else:
            self.dsn = self.database

    def parse_connect_uri(self, uri):
        """Parse a URI as the base class does, then apply Oracle conventions.

        When the URI names no database, the host part is taken to be the
        database, the hostname is cleared and the port is reset to the default.
        Returns True if *uri* was recognised as a URI, False otherwise.
        """
        if DatabaseInstance.parse_connect_uri(self, uri):
            if not self.database:
                self.database = self.hostname
                self.hostname = None
                self.port = self.default_port
            return True
        return False

    def connect(self):
        """Open the cx_Oracle connection and store it in ``self.connection``."""
        self.connection = cx_Oracle.connect(user = self.username, password = self.password,
                                            dsn = self.dsn, mode = self.mode)

    all_object_qry = """SELECT * FROM (
                          SELECT ao.owner, ao.object_type, ao.object_name, NULL AS synonym_name
                          FROM   all_objects ao
                          WHERE  ao.object_type %(object_type_comparitor)s :object_type
                          AND    ao.owner = :schema
                          AND    ao.object_name %(name1_comparitor)s :name1
                          UNION
                          SELECT asyn.table_owner, ao.object_type, asyn.table_name, asyn.synonym_name
                          FROM   all_synonyms asyn
                          JOIN   all_objects ao ON ( asyn.table_owner = ao.owner
                                                     AND asyn.table_name = ao.object_name)
                          WHERE  %(all)s
                          AND    ao.object_type %(object_type_comparitor)s :object_type
                          AND    asyn.synonym_name %(name1_comparitor)s :name1
                          AND    asyn.owner IN (:schema, 'PUBLIC')
                          UNION
                          SELECT ao.owner, ao.object_type, ao.object_name, NULL AS synonym_name
                          FROM   all_objects ao
                          WHERE  :name1 != '%%'
                          AND    ao.object_type %(object_type_comparitor)s :object_type
                          AND    ao.owner %(name1_comparitor)s :name1
                          AND    ao.object_name %(name2_comparitor)s :name2
                          UNION
                          SELECT asyn.table_owner, ao.object_type, asyn.table_name, asyn.synonym_name
                          FROM   all_synonyms asyn
                          JOIN   all_objects ao ON ( asyn.table_owner = ao.owner
                                                     AND asyn.table_name = ao.object_name)
                          WHERE  %(all)s
                          AND    ao.object_type %(object_type_comparitor)s :object_type
                          AND    asyn.synonym_name %(name2_comparitor)s :name2
                          AND    asyn.owner %(name1_comparitor)s :name1
                        )
                        ORDER BY object_type, owner, object_name %(sort)s"""
    column_qry = """SELECT atc.owner, ao.object_type, atc.table_name, atc.column_name
                    FROM   all_tab_columns atc
                    JOIN   all_objects ao ON (atc.table_name = ao.object_name AND atc.owner = ao.owner)
                    WHERE  atc.owner %s %s
                    AND    atc.column_name %s %s """
    source_qry = """SELECT owner, type, name, line, text
                    FROM   all_source
                    WHERE  owner %s %s
                    AND    UPPER(text) LIKE %s"""

    def bindSyntax(self, varname):
        """Return the named placeholder ``:varname``."""
        return ':' + varname

    def bindVariables(self, binds):
        """Puts a tuple of (name, value) pairs into the bind format desired by cx_Oracle"""
        return dict((b[0], b[1].upper()) for b in binds)

    # Object type -> factory called as factory(name, cursor, owner).  The
    # lambdas adapt gerald classes that take the object type as an extra
    # second argument.
    gerald_types = {'TABLE': gerald.oracle_schema.Table,
                    'VIEW': gerald.oracle_schema.View,
                    'TRIGGER': gerald.oracle_schema.Trigger,
                    'SEQUENCE': gerald.oracle_schema.Sequence,
                    'PACKAGE': lambda name, cursor, owner: gerald.oracle_schema.Package(name, 'PACKAGE', cursor, owner),
                    'DATABASE LINK': gerald.oracle_schema.DatabaseLink,
                    'FUNCTION': lambda name, cursor, owner: gerald.oracle_schema.CodeObject(name, 'FUNCTION', cursor, owner),
                    'PROCEDURE': lambda name, cursor, owner: gerald.oracle_schema.CodeObject(name, 'PROCEDURE', cursor, owner),
                    }


if __name__ == '__main__':
    doctest.testmod()