view sqlpython/connections.py @ 510:c8de86e7cd06

lambdas for calling code objects
author catherine.devlin@gmail.com
date Fri, 24 Sep 2010 19:00:37 -0400
parents 85495d4d6c73
children eccf817b0fbc
line wrap: on
line source

import re
import os
import getpass
import gerald
import time
import optparse
import doctest
import pyparsing

# Maps an RDBMS name ('oracle', 'postgres', 'mysql') to the gerald schema
# class for it.  An entry is added only when the matching DB-API driver
# imports successfully, so the keys show which adapters are installed.
# The driver modules bound here (cx_Oracle, psycopg2, MySQLdb) are also
# called directly by the *Instance classes' connect() methods.
gerald_classes = {}

try:
    import cx_Oracle
    gerald_classes['oracle'] = gerald.oracle_schema.User
except ImportError:
    pass  # driver not installed: no 'oracle' entry

try:
    import psycopg2
    gerald_classes['postgres'] = gerald.PostgresSchema
except ImportError:
    pass  # driver not installed: no 'postgres' entry

try:
    import MySQLdb
    gerald_classes['mysql'] = gerald.MySQLSchema
except ImportError:
    pass  # driver not installed: no 'mysql' entry

# Left disabled, so this module still imports when no adapter is installed.
#if not gerald_classes:
#    raise ImportError, 'No Python database adapters installed!'

class ObjectDescriptor(object):
    """Describes one database object by name so it can be matched against patterns.

    Attributes set by the constructor:
        fullname: the name as given, possibly owner-qualified ("OWNER.NAME").
        dbobj: the underlying object being described.
        type: lower-cased object type -- ``dbobj.type`` when it has one,
            otherwise derived from the class name of ``dbobj``.
        path: ``"type/fullname"``.
        owner: lower-cased owner prefix of ``fullname``, or None if unqualified.
        unqualified_name: ``fullname`` without its owner prefix.
        unqualified_path: ``"type/unqualified_name"``.
    """
    def __init__(self, name, dbobj):
        self.fullname = name
        self.dbobj = dbobj
        if hasattr(self.dbobj, 'type'):
            self.type = self.dbobj.type.lower()
        else:
            # e.g. "<class 'some.module.Table'>" -> "table"
            self.type = str(type(self.dbobj)).split('.')[-1].lower().strip("'>")
        self.path = '%s/%s' % (self.type, self.fullname)
        if '.' in self.fullname:
            # Split on the first dot only; a name with further dots used to
            # raise ValueError ("too many values to unpack").
            (self.owner, self.unqualified_name) = self.fullname.split('.', 1)
            self.owner = self.owner.lower()
        else:
            (self.owner, self.unqualified_name) = (None, self.fullname)
        self.unqualified_path = '%s/%s' % (self.type, self.unqualified_name)

    def match_pattern(self, pattern, specific_owner=None):
        """Test this object against a case-insensitive regex ``pattern``.

        The owner check passes when this object has no owner, when no
        ``specific_owner`` is requested, or when the owners are equal,
        ignoring case.
        - Empty pattern: the result is just the owner check.
        - Pattern containing an escaped dot (``\\.``): it is treated as
          owner-qualified.  It is matched against ``fullname`` or ``path``,
          and the owner check is skipped.
        - Otherwise the owner check must pass, and the pattern must match
          (from the start) the type, ``type + '/'``, the unqualified name,
          or the unqualified path.

        Returns a truthy value (a match object or True) on success and a
        falsy value (None or False) otherwise.
        """
        right_owner = (not self.owner) or (not specific_owner) or (self.owner == specific_owner.lower())
        if not pattern:
            return right_owner
        compiled = re.compile(pattern, re.IGNORECASE)
        if r'\.' in pattern:
            return compiled.match(self.fullname) or compiled.match(self.path)
        return right_owner and (compiled.match(self.type) or
                                compiled.match(self.type + r'/') or
                                compiled.match(self.unqualified_name) or
                                compiled.match(self.unqualified_path))
       
class DBOjbect(object):
    """Plain record of a database object: its schema, object type and name.

    (The misspelled class name is kept unchanged in case callers refer to it.)
    """
    def __init__(self, schema, object_type, name):
        self.schema, self.type, self.name = schema, object_type, name
        
class OptionTestDummy(object):
    """Stand-in for parsed command-line options.

    Every option defaults to None (``mode`` defaults to 0).  Keyword
    arguments override those defaults on the instance; positional
    arguments are accepted and ignored.
    """
    mode = 0
    mysql = postgres = None
    username = password = None
    hostname = port = database = None

    def __init__(self, *args, **kwargs):
        vars(self).update(kwargs)
        
class DatabaseInstance(object):
    """A connection to one database, plus catalog-query helpers.

    The constructor decides which RDBMS is meant and re-classes ``self`` to
    OracleInstance, PostgresInstance or MySQLInstance.  It uses the
    ``opts.mysql``/``opts.postgres`` flags, a ``rdbms://`` prefix on
    ``arg``, or ``default_rdbms``.  The subclasses provide ``rdbms``,
    ``default_port``, ``connect()`` and the RDBMS-specific helpers and SQL
    text used by ``objects()``, ``columns()`` and ``source()``.
    """
    username = None
    password = None
    hostname = None
    port = None
    database = None
    mode = 0
    # rdbms://connect_string, e.g. "oracle://user:password@db".
    connection_uri_parser = re.compile('(?P<rdbms>postgres|oracle|mysql|sqlite|mssql)://?(?P<connect_string>.*$)', re.IGNORECASE)
    # user/password@[host[:port]/]database [as sysdba|sysoper].
    # TCP ports can have five digits, so the port group accepts 1-5 digits.
    oracle_style_connection_parser = re.compile(r'(?P<username>[^/\s@]*)(/(?P<password>[^/\s@]*))?(@((?P<hostname>[^/\s:]*)(:(?P<port>\d{1,5}))?/)?(?P<database>[^/\s:]*))?(\s+as\s+(?P<mode>sys(dba|oper)))?',
                                     re.IGNORECASE)
    # "[database [username]]"; every part is optional, so search() always matches.
    connection_parser = re.compile(r'((?P<database>\S+)(\s+(?P<username>\S+))?)?')
    def __init__(self, arg, opts, default_rdbms = 'oracle'):
        'no docstring'
        # NOTE: the string above is the real docstring, so doctest.testmod()
        # does not collect the examples below.  Each of them would reach
        # self.connect() and try to open a live connection.
        '''
        >>> opts = OptionTestDummy(postgres=True, password='password')        
        >>> DatabaseInstance('thedatabase theuser', opts).uri()        
        'postgres://theuser:password@localhost:5432/thedatabase'
        >>> opts = OptionTestDummy(password='password')
        >>> DatabaseInstance('oracle://user:password@db', opts).uri()        
        'oracle://user:password@db'
        >>> DatabaseInstance('user/password@db', opts).uri()
        'oracle://user:password@db'
        >>> DatabaseInstance('user/password@db as sysdba', opts).uri()
        'oracle://user:password@db?mode=2'
        >>> DatabaseInstance('user/password@thehost/db', opts).uri()        
        'oracle://user:password@thehost:1521/db'
        >>> opts = OptionTestDummy(postgres=True, hostname='thehost', password='password')
        >>> DatabaseInstance('thedatabase theuser', opts).uri()        
        'postgres://theuser:password@thehost:5432/thedatabase'
        >>> opts = OptionTestDummy(mysql=True, password='password')
        >>> DatabaseInstance('thedatabase theuser', opts).uri()        
        'mysql://theuser:password@localhost:3306/thedatabase'
        >>> opts = OptionTestDummy(mysql=True, password='password')
        >>> DatabaseInstance('thedatabase', opts).uri()        
        'mysql://cat:password@localhost:3306/thedatabase'
        '''
        self.arg = arg
        self.opts = opts
        self.default_rdbms = default_rdbms
        self.determine_rdbms()  # may be altered later as connect string is parsed
        if not self.parse_connect_uri(arg):
            self.set_defaults()
            connectargs = self.connection_parser.search(self.arg)
            # The 'database' group is None when the string has no database
            # part (e.g. "" or an Oracle-style "user/password").  Treat that
            # as "no '@'" instead of failing on ``'@' in None``.
            if '@' in (connectargs.group('database') or ''):
                # Looks like user/password@db: reparse with the Oracle grammar.
                connectargs = OracleInstance.connection_parser.search(self.arg)
            if connectargs:
                for param in ('username', 'password', 'database', 'port', 'hostname', 'mode'):
                    # Explicit command-line options win over the connect string.
                    if hasattr(opts, param) and getattr(opts, param):
                        setattr(self, param, getattr(opts, param))
                    else:
                        try:
                            if connectargs.group(param):
                                setattr(self, param, connectargs.group(param))
                        except IndexError:
                            # This parser has no group with that name.
                            pass
        self.set_corrections()
        if not self.password:
            self.password = getpass.getpass()
        self.connect()
    def parse_connect_uri(self, uri):
        """Parse an ``rdbms://...`` URI into the connection attributes.

        Re-classes ``self`` to match the URI's scheme.  Returns True if
        ``uri`` was such a URI, otherwise False, leaving the attributes alone.
        """
        results = self.connection_uri_parser.search(uri)
        if not results:
            return False
        self.set_class_from_rdbms_name(results.group('rdbms'))
        r = gerald.utilities.dburi.Connection().parse_uri(results.group('connect_string'))
        self.username = r.get('user') or self.username
        self.password = r.get('password') or self.password
        self.hostname = r.get('host') or self.hostname
        # Honour a port given in the URI; it used to be silently discarded.
        self.port = r.get('port') or self.port or self.default_port
        self.database = r.get('db_name')
        return True
    def set_class_from_rdbms_name(self, rdbms_name):
        """Re-class ``self`` to the subclass whose ``rdbms`` equals ``rdbms_name``.

        The comparison ignores case, because the URI parser accepts any case.
        Leaves the class unchanged when nothing matches (e.g. sqlite, mssql).
        """
        wanted = (rdbms_name or '').lower()
        for cls in (OracleInstance, PostgresInstance, MySQLInstance):
            if cls.rdbms == wanted:
                self.__class__ = cls
                break
    def uri(self):
        """Return ``rdbms://user:password@host:port/database`` for this connection."""
        return '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password,
                                         self.hostname, self.port, self.database)
    def determine_rdbms(self):
        """Pick the subclass using the --mysql/--postgres flags or a URI prefix.

        Falls back to ``self.default_rdbms``.
        """
        if self.opts.mysql or self.arg.startswith('mysql://'):
            self.__class__ = MySQLInstance
        elif self.opts.postgres or self.arg.startswith(('postgres://', 'postgresql://')):
            self.__class__ = PostgresInstance
        else:
            self.set_class_from_rdbms_name(self.default_rdbms)
    def set_defaults(self):
        """Set connection defaults before the connect string is parsed."""
        self.port = self.default_port
    def set_corrections(self):
        """Hook for subclasses to normalise parsed values; a no-op here."""
        pass
    def set_instance_number(self, instance_number):
        """Record this connection's number and build the prompt from it."""
        self.instance_number = instance_number
        self.prompt = "%d:%s@%s> " % (self.instance_number, self.username, self.database)
    # pyparsing grammar for ls-style targets:
    # owner/type/name, or type/owner.name (every part optional).
    sqlname = pyparsing.Word(pyparsing.alphas + '$_#%*', pyparsing.alphanums + '$_#%*')
    ls_parser = ( (pyparsing.Optional(sqlname("owner") + "/") +
                   pyparsing.Optional(sqlname("type") + "/") +
                   pyparsing.Optional(sqlname("name")) +
                   pyparsing.stringEnd )
                   | ( pyparsing.Optional(sqlname("type") + "/") +
                       pyparsing.Optional(sqlname("owner") + ".") +
                       pyparsing.Optional(sqlname("name")) +
                       pyparsing.stringEnd ))
    # Optional leading object type, followed by the rest of the target.
    identifier_regex = re.compile(
                       r'((?P<object_type>DATABASE LINK|DIRECTORY|FUNCTION|INDEX|JOB|MATERIALIZED VIEW|PACKAGE|PROCEDURE|SEQUENCE|SYNONYM|TABLE|TRIGGER|TYPE|VIEW|BASE TABLE)($|[\\/.\s])+)?(?P<remainder>.*)',
                       re.IGNORECASE)
    def comparison_operator(self, target):
        """Return 'LIKE' if ``target`` contains SQL wildcards (% or _), else '='."""
        if ('%' in target) or ('_' in target):
            return 'LIKE'
        return '='
    def sql_format_wildcards(self, target):
        """Translate shell-style wildcards (* and ?) into SQL ones (% and _)."""
        return target.replace('*', '%').replace('?', '_')
    def comparitor(self, target):
        """Same as ``comparison_operator``; kept under this name for its callers."""
        return self.comparison_operator(target)
    def _owner_filter(self, opts):
        """Return '%' (every owner) when ``opts.all`` is set, else the connected user."""
        if getattr(opts, 'all', False):
            return '%'
        return self.username
    def objects(self, target, opts):
        """Run ``all_object_qry`` for objects matching ``target`` and return the cursor.

        ``target`` may start with an object type (see ``identifier_regex``).
        The rest is split on '.' into up to two name parts, and * or ? are
        turned into SQL wildcards.  ``opts.reverse`` sorts names descending.
        ``opts.all`` replaces the query's ``1 = 0`` guard with ``1 = 1``.
        """
        match = self.identifier_regex.search(target)
        object_type = self.name_case(match.group('object_type') or '%')
        remainder = self.name_case(self.sql_format_wildcards(match.group('remainder')))
        # Pad so names[0] and names[1] always exist; empty parts become '%'.
        names = [n.strip() or '%' for n in remainder.split('.')] + ['%', '%']
        replacements = {'name1_comparitor': self.comparitor(names[0]),
                        'name2_comparitor': self.comparitor(names[1]),
                        'object_type_comparitor': self.comparitor(object_type),
                        'sort': 'ASC',
                        'all': '1 = 0'}
        if getattr(opts, 'reverse', False):
            replacements['sort'] = 'DESC'
        if getattr(opts, 'all', False):
            replacements['all'] = '1 = 1'
        qry = self.all_object_qry % replacements
        binds = {'schema': self.name_case(self.username), 'object_type': object_type,
                 'name1': names[0], 'name2': names[1]}
        curs = self.connection.cursor()
        curs.execute(qry, binds)
        return curs
    def columns(self, target, opts):
        """Run ``column_qry`` for columns whose name matches ``target``; return the cursor.

        * and ? in ``target`` become SQL wildcards.  The owner bind is the
        connected user, or '%' when ``opts.all`` is set.
        """
        target = self.sql_format_wildcards(target)
        owner = self._owner_filter(opts)
        qry = self.column_qry % (self.comparison_operator(owner), self.bindSyntax('owner'),
                                 self.comparison_operator(target), self.bindSyntax('colname'))
        binds = (('owner', owner), ('colname', target))
        curs = self.connection.cursor()
        curs.execute(qry, self.bindVariables(binds))
        return curs
    def source(self, target, opts):
        """Run ``source_qry`` for source text LIKE ``target``; return the cursor.

        The owner bind is the connected user, or '%' when ``opts.all`` is set.
        """
        owner = self._owner_filter(opts)
        qry = self.source_qry % (self.comparison_operator(owner), self.bindSyntax('owner'),
                                                                  self.bindSyntax('target'))
        binds = (('owner', owner), ('target', target))
        curs = self.connection.cursor()
        curs.execute(qry, self.bindVariables(binds))
        return curs

    gerald_types = {'TABLE': gerald.oracle_schema.Table,
                    'VIEW': gerald.oracle_schema.View}
    def object_metadata(self, owner, object_type, name):
        """Build the gerald object describing ``owner``'s ``object_type`` called ``name``.

        Raises NotImplementedError for types missing from ``gerald_types``.
        """
        if object_type in self.gerald_types:
            return self.gerald_types[object_type](name, self.connection.cursor(), owner)
        raise NotImplementedError('%s not implemented for this RDBMS' % object_type)

# Command-line options accepted by the connect command.
parser = optparse.OptionParser()
parser.add_option('--postgres', action='store_true', help='Connect to postgreSQL: `connect --postgres [DBNAME [USERNAME]]`')
parser.add_option('--oracle', action='store_true', help='Connect to an Oracle database')
parser.add_option('--mysql', action='store_true', help='Connect to a MySQL database')
parser.add_option('-H', '--hostname', type='string',
                                    help='Machine where database is hosted')
parser.add_option('-p', '--port', type='int',
                                    help='Port to connect to')
parser.add_option('--password', type='string',
                                    help='Password')
parser.add_option('-d', '--database', type='string',
                                    help='Database name to connect to')
parser.add_option('-U', '--username', type='string',
                                    help='Database user name to connect as')

def connect(connstr):
    """Parse connect-command arguments with ``parser`` and print the result.

    Args:
        connstr: a list of argument strings, or one string, which is split
            on whitespace.  optparse needs a list; a plain string used to
            crash inside ``parse_args``.  None means ``sys.argv[1:]``, as in
            optparse.
    Returns:
        The ``(options, args)`` pair from ``parser.parse_args``.
    """
    if hasattr(connstr, 'split'):
        connstr = connstr.split()
    (options, args) = parser.parse_args(connstr)
    print(options)
    print(args)
    return options, args

class MySQLInstance(DatabaseInstance):
    """Connection to a MySQL database through MySQLdb."""
    rdbms = 'mysql'
    default_port = 3306
    def set_defaults(self):
        """Default to localhost on the standard port.

        Both the user and the database default to the $USER environment
        variable.
        """
        self.port = self.default_port
        self.hostname = 'localhost'
        self.username = os.getenv('USER')
        self.database = os.getenv('USER')
    def connect(self):
        """Open the MySQLdb connection, with the session's sql_mode set to 'ANSI'."""
        self.connection = MySQLdb.connect(host = self.hostname, user = self.username,
                                passwd = self.password, db = self.database,
                                port = self.port, sql_mode = 'ANSI')
    def bindSyntax(self, varname):
        """MySQLdb uses positional ``%s`` placeholders, whatever the variable's name."""
        return '%s'
    def bindVariables(self, binds):
        """Put a sequence of (name, value) pairs into the positional form MySQLdb expects.

        Returns a tuple: DB-API parameters must be a sequence or mapping, and
        the generator returned previously was neither.
        """
        return tuple(value for (_name, value) in binds)
    # NOTE(review): the two queries below use all_tab_columns, all_objects
    # and all_source, which match the Oracle queries in OracleInstance.
    # MySQL normally exposes catalog data through information_schema, and
    # the unqualified "owner" in column_qry is ambiguous across the join.
    # Confirm whether these have ever worked against MySQL.
    column_qry = """SELECT atc.owner, ao.object_type, atc.table_name, atc.column_name      
                    FROM   all_tab_columns atc
                    JOIN   all_objects ao ON (atc.table_name = ao.object_name AND atc.owner = ao.owner)
                    WHERE  owner %s %s
                    AND    column_name %s %s """
    source_qry = """SELECT owner, type, name, line, text
                    FROM   all_source
                    WHERE  owner %s %s
                    AND    UPPER(text) LIKE %s"""
    
class PostgresInstance(DatabaseInstance):
    """Connection to a PostgreSQL database through psycopg2."""
    rdbms = 'postgres'
    default_port = 5432
    def name_case(self, s):
        """Lower-case an identifier before it is used as a bind value."""
        return s.lower()
    def set_defaults(self):
        """Fill connection defaults from the libpq-style environment variables.

        PGPORT, PGHOST, PGDATABASE and PGUSER are read.  The fallbacks are
        the default port, 'localhost', None and $USER respectively.
        """
        self.port = os.getenv('PGPORT') or self.default_port
        # This used to read ORACLE_SID, an Oracle setting that has nothing
        # to do with which PostgreSQL database to use.
        self.database = os.getenv('PGDATABASE')
        self.hostname = os.getenv('PGHOST') or 'localhost'
        self.username = os.getenv('PGUSER') or os.getenv('USER')
    def connect(self):
        """Open the psycopg2 connection from the parsed attributes."""
        self.connection = psycopg2.connect(host = self.hostname, user = self.username,
                                 password = self.password, database = self.database,
                                 port = self.port)
    def bindSyntax(self, varname):
        """psycopg2 named placeholder, e.g. ``%(owner)s``."""
        return '%%(%s)s' % varname
    def bindVariables(self, binds):
        """Put a sequence of (name, value) pairs into the mapping psycopg2 expects.

        Values are lower-cased.
        """
        return dict((b[0], b[1].lower()) for b in binds)
    # Built with "%" formatting first (the %(...)s keys), so psycopg2's own
    # placeholders are written doubled as %%(...)s.
    all_object_qry = """SELECT * FROM (
                        SELECT ns.nspname AS schema, 
                               CASE c.relkind
                               WHEN 'r' THEN 'table'
                               WHEN 'i' THEN 'index'
                               WHEN 'S' THEN 'sequence'
                               WHEN 'v' THEN 'view'
                               WHEN 'c' THEN 'composite type'
                               WHEN 't' THEN 'toast table'
                               END AS object_type, 
                               c.relname AS object_name, 
                               NULL AS synonym_name
                        FROM   pg_namespace ns
                        JOIN   pg_class c ON (ns.oid = c.relnamespace)
                        WHERE  ( (     ns.nspname = %%(schema)s
                                    OR position(ns.nspname in (SELECT setting FROM pg_settings WHERE name = 'search_path')) > 0 
                                    OR %(all)s 
                                 )
                                AND    c.relname %(name1_comparitor)s %%(name1)s 
                               )
                        OR     (     ns.nspname %(name1_comparitor)s %%(name1)s
                                 AND c.relname %(name2_comparitor)s %%(name2)s 
                               )
                        ) subq
                        WHERE  object_type %(object_type_comparitor)s %%(object_type)s
                        ORDER BY object_type, schema, object_name %(sort)s"""
    column_qry = """SELECT c.table_schema, t.table_type, c.table_name, c.column_name      
                    FROM   information_schema.columns c
                    JOIN   information_schema.tables t ON (c.table_schema = t.table_schema
                                                           AND c.table_name = t.table_name)
                    WHERE  ( (c.table_schema %s %s) OR (c.table_schema = 'public'))
                    AND    c.column_name %s %s """
    # NOTE(review): all_source matches the Oracle query in OracleInstance;
    # confirm this view exists in the target PostgreSQL databases.
    source_qry = """SELECT owner, type, name, line, text
                    FROM   all_source
                    WHERE  owner %s %s
                    AND    UPPER(text) LIKE %s"""
    gerald_types = {'table': gerald.postgres_schema.Table,
                    'view': gerald.postgres_schema.View,
                    'trigger': gerald.postgres_schema.Trigger}
    
class OracleInstance(DatabaseInstance):
    """Connection to an Oracle database through cx_Oracle."""
    rdbms = 'oracle'
    default_port = 1521
    # user/password@[host[:port]/]database [as sysdba|sysoper]; every part
    # is optional.  TCP ports can have five digits, so the port group
    # accepts 1-5 digits (it used to stop at 4, which broke e.g. :15210).
    connection_parser = re.compile(r'(?P<username>[^/\s@]*)(/(?P<password>[^/\s@]*))?(@((?P<hostname>[^/\s:]*)(:(?P<port>\d{1,5}))?/)?(?P<database>[^/\s:]*))?(\s+as\s+(?P<mode>sys(dba|oper)))?',
                                     re.IGNORECASE)
    def name_case(self, s):
        """Upper-case an identifier before it is used as a bind value."""
        return s.upper()
    def uri(self):
        """Return an ``oracle://`` URI for this connection.

        The ``host:port/`` part is included only when a hostname is known,
        and ``?mode=N`` is appended when a connect mode is set.
        """
        if self.hostname:
            uri = '%s://%s:%s@%s:%s/%s' % (self.rdbms, self.username, self.password,
                                           self.hostname, self.port, self.database)
        else:
            uri = '%s://%s:%s@%s' % (self.rdbms, self.username, self.password, self.database)
        if self.mode:
            uri = '%s?mode=%d' % (uri, self.mode)
        return uri
    def set_defaults(self):
        """Default to the standard port and to the database named by $ORACLE_SID."""
        self.port = self.default_port
        self.database = os.getenv('ORACLE_SID')
    def set_corrections(self):
        """Normalise parsed values for cx_Oracle.

        A textual mode such as 'sysdba' becomes the cx_Oracle constant of
        the same name.  A mode that is already non-text is left unchanged;
        it used to fail on ``.upper()``.  ``self.dsn`` is a makedsn() string
        when a hostname is known, otherwise just the database name.
        """
        if self.mode and hasattr(self.mode, 'upper'):
            self.mode = getattr(cx_Oracle, self.mode.upper())
        if self.hostname:
            self.dsn = cx_Oracle.makedsn(self.hostname, self.port, self.database)
        else:
            self.dsn = self.database
    def parse_connect_uri(self, uri):
        """Parse an ``oracle://`` URI.

        ``oracle://user:password@db`` has no database part, so the parsed
        "host" is really the database name; it is moved over and the
        hostname is cleared.
        """
        if DatabaseInstance.parse_connect_uri(self, uri):
            if not self.database:
                self.database = self.hostname
                self.hostname = None
                self.port = self.default_port
            return True
        return False
    def connect(self):
        """Open the cx_Oracle connection using the DSN built by set_corrections()."""
        self.connection = cx_Oracle.connect(user = self.username, password = self.password,
                                  dsn = self.dsn, mode = self.mode)
    # Own objects, objects reachable through synonyms, and (when a first name
    # part other than '%' is given) OWNER.NAME lookups.
    all_object_qry = """SELECT * FROM (
                        SELECT ao.owner, ao.object_type, ao.object_name, NULL AS synonym_name
                        FROM   all_objects ao
                        WHERE  ao.object_type %(object_type_comparitor)s :object_type
                        AND    ao.owner = :schema 
                        AND    ao.object_name %(name1_comparitor)s :name1
                        UNION
                        SELECT asyn.table_owner, ao.object_type, asyn.table_name, asyn.synonym_name 
                        FROM   all_synonyms asyn
                        JOIN   all_objects ao ON (    asyn.table_owner = ao.owner
                                                  AND asyn.table_name = ao.object_name)
                        WHERE  %(all)s
                        AND    ao.object_type %(object_type_comparitor)s :object_type
                        AND    asyn.synonym_name %(name1_comparitor)s :name1
                        AND    asyn.owner IN (:schema, 'PUBLIC')
                        UNION
                        SELECT ao.owner, ao.object_type, ao.object_name, NULL AS synonym_name
                        FROM   all_objects ao
                        WHERE  :name1 != '%%' 
                        AND    ao.object_type %(object_type_comparitor)s :object_type
                        AND    ao.owner %(name1_comparitor)s :name1
                        AND    ao.object_name %(name2_comparitor)s :name2
                        UNION
                        SELECT asyn.table_owner, ao.object_type, asyn.table_name, asyn.synonym_name 
                        FROM   all_synonyms asyn
                        JOIN   all_objects ao ON (    asyn.table_owner = ao.owner
                                                  AND asyn.table_name = ao.object_name)
                        WHERE  %(all)s
                        AND    ao.object_type %(object_type_comparitor)s :object_type
                        AND    asyn.synonym_name %(name2_comparitor)s :name2
                        AND    asyn.owner %(name1_comparitor)s :name1
                        ) ORDER BY object_type, owner, object_name %(sort)s"""
    column_qry = """SELECT atc.owner, ao.object_type, atc.table_name, atc.column_name      
                    FROM   all_tab_columns atc
                    JOIN   all_objects ao ON (atc.table_name = ao.object_name AND atc.owner = ao.owner)
                    WHERE  atc.owner %s %s
                    AND    atc.column_name %s %s """
    source_qry = """SELECT owner, type, name, line, text
                    FROM   all_source
                    WHERE  owner %s %s
                    AND    UPPER(text) LIKE %s"""
    def bindSyntax(self, varname):
        """cx_Oracle named placeholder, e.g. ``:owner``."""
        return ':' + varname
    def bindVariables(self, binds):
        """Put a sequence of (name, value) pairs into the mapping cx_Oracle expects.

        Values are upper-cased.
        """
        return dict((b[0], b[1].upper()) for b in binds)
    # Package, function and procedure wrappers pass the type name through
    # to gerald explicitly.
    gerald_types = {'TABLE': gerald.oracle_schema.Table,
                    'VIEW': gerald.oracle_schema.View,
                    'TRIGGER': gerald.oracle_schema.Trigger,
                    'SEQUENCE': gerald.oracle_schema.Sequence,
                    'PACKAGE': lambda name, cursor, owner: gerald.oracle_schema.Package(name, 'PACKAGE', cursor, owner),
                    'DATABASE LINK': gerald.oracle_schema.DatabaseLink,
                    'FUNCTION': lambda name, cursor, owner: gerald.oracle_schema.CodeObject(name, 'FUNCTION', cursor, owner),
                    'PROCEDURE': lambda name, cursor, owner: gerald.oracle_schema.CodeObject(name, 'PROCEDURE', cursor, owner),
                    }

if '__main__' == __name__:
    # Running the module directly executes whatever doctests it collects.
    doctest.testmod()