Mercurial > sqlpython
changeset 407:188c86d4a11e
struggling with option parsing
author | catherine@DellZilla |
---|---|
date | Wed, 14 Oct 2009 13:57:29 -0400 |
parents | 4b481c4293b8 |
children | 80413ef3699a |
files | sqlpython/completion.py sqlpython/schemagroup.py sqlpython/sqlpyPlus.py sqlpython/sqlpython.py |
diffstat | 4 files changed, 168 insertions(+), 82 deletions(-) [+] |
line wrap: on
line diff
--- a/sqlpython/completion.py Sun Oct 11 10:03:18 2009 -0400 +++ b/sqlpython/completion.py Wed Oct 14 13:57:29 2009 -0400 @@ -1,4 +1,4 @@ -import pyparsing, re +import pyparsing, re, doctest sqlStyleComment = pyparsing.Literal("--") + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n")) keywords = {'order by': pyparsing.Keyword('order', caseless=True) + @@ -40,30 +40,31 @@ results.sort(cmp=lambda x,y:cmp(x[1],y[1])) return results -at_beginning = re.compile(r'^\s*\S+\s*$') +at_beginning = re.compile(r'^\s*\S+$') def whichSegment(statement): - if at_beginning.search(statement): + ''' + >>> whichSegment("SELECT col FROM t") + 'from' + >>> whichSegment("SELECT * FROM t") + 'from' + >>> whichSegment("DESC ") + 'DESC' + >>> whichSegment("DES") + 'beginning' + >>> whichSegment("") + 'beginning' + >>> whichSegment("select ") + 'select' + + ''' + if (not statement) or at_beginning.search(statement): return 'beginning' results = orderedParseResults(keywords.values(), statement) if results: return ' '.join(results[-1][0]) else: - return None - -oracleIdentifierCharacters = pyparsing.alphanums + '_#$' -def wordInProgress(statement): - result = [] - letters = list(statement) - letters.reverse() - for letter in letters: - if letter not in oracleIdentifierCharacters: - result.reverse() - return ''.join(result) - result.append(letter) - result.reverse() - return ''.join(result) - - + return statement.split(None,1)[0] + reserved = ''' access add @@ -173,4 +174,7 @@ view whenever where - with '''.split() \ No newline at end of file + with '''.split() + +if __name__ == '__main__': + doctest.testmod()
--- a/sqlpython/schemagroup.py Sun Oct 11 10:03:18 2009 -0400 +++ b/sqlpython/schemagroup.py Wed Oct 14 13:57:29 2009 -0400 @@ -47,11 +47,40 @@ [s.qual_table_names for s in self.schemas.values()], []) +class OracleSchemaAccess(object): + child_type = gerald.OracleSchema + current_database_time_query = 'SELECT sysdate FROM dual' + def latest_ddl_timestamp_query(self, username, connection): + curs = connection.cursor() + curs.execute('''SELECT owner, MAX(last_ddl_time) + FROM all_objects + GROUP BY owner + -- sort :username to top + ORDER BY REPLACE(owner, :username, 'A'), owner''', + {'username': username.upper()}) + return curs + +class PostgresSchemaAccess(object): + child_type = gerald.PostgresSchema + current_database_time_query = 'SELECT current_time' + def latest_ddl_timestamp_query(self, username, connection): + curs = connection.cursor() + curs.execute("""SELECT '%s', current_time""" % username) + return curs + +class MySQLSchemaAccess(object): + child_type = gerald.MySQLSchema + current_database_time_query = 'SELECT sysdate FROM dual' + def latest_ddl_timestamp_query(self, username, connection): + curs = connection.cursor() + curs.execute("""SELECT '%s', current_time""" % username) + return curs + class SchemaDict(dict): - schema_types = {'oracle': gerald.OracleSchema} + schema_types = {'oracle': OracleSchemaAccess, 'postgres': PostgresSchemaAccess, 'mysql': MySQLSchemaAccess} def __init__(self, dct, rdbms, user, connection, connection_string): dict.__init__(self, dct) - self.child_type = self.schema_types[rdbms] + self.schema_access = self.schema_types[rdbms]() self.user = user self.connection = connection self.gerald_connection_string = gerald_connection_string(connection_string) @@ -59,19 +88,24 @@ self.complete = 0 def refresh_asynch(self): self.refresh_thread.start() + current_database_time_sql = {gerald.OracleSchema: 'SELECT sysdate FROM dual', + gerald.PostgresSchema: 'SELECT current_time'} def get_current_database_time(self): curs = self.connection.cursor() - curs.execute('SELECT sysdate FROM dual') + curs.execute(self.schema_access.current_database_time_query) return curs.fetchone()[0] + def refresh_times(self, target_schema): + now = self.get_current_database_time() + result = [] + for (schema_name, schema) in self.items(): + if (not target_schema) or (target_schema.lower() == schema_name.lower()): + result.append('%s: %s (%s ago)' % (schema_name, schema.refreshed, now - schema.refreshed)) + result.sort() + return '\n'.join(result) + def refresh(self): current_database_time = self.get_current_database_time() - curs = self.connection.cursor() - curs.execute('''SELECT owner, MAX(last_ddl_time) - FROM all_objects - GROUP BY owner - -- sort :username to top - ORDER BY REPLACE(owner, :username, 'A'), owner''', - {'username': self.user.upper()}) + curs = self.schema_access.latest_ddl_timestamp_query(self.user, self.connection) for (owner, last_ddl_time) in curs.fetchall(): if (owner not in self) or (self[owner].refreshed < last_ddl_time): self.refresh_one(owner, current_database_time) @@ -81,11 +115,12 @@ self.column_names = [s.column_names for s in self.values()] self.columns = reduce(operator.add, [s.column_names for s in self.values()]) self.complete = 'all' - print 'metadata discovered' def refresh_one(self, owner, current_database_time=None): + #owner = owner.upper() + owner = str(owner) if not current_database_time: current_database_time = self.get_current_database_time() - self[owner] = self.child_type(owner, self.gerald_connection_string) + self[owner] = self.schema_access.child_type(owner, self.gerald_connection_string) self[owner].refreshed = current_database_time build_column_list(self[owner])
--- a/sqlpython/sqlpyPlus.py Sun Oct 11 10:03:18 2009 -0400 +++ b/sqlpython/sqlpyPlus.py Wed Oct 14 13:57:29 2009 -0400 @@ -329,6 +329,16 @@ else: return '(BLOB not saved, check bloblimit)' +class Abbreviatable_List(list): + def match(self, target): + target = target.lower() + result = [i for i in self if i.startswith(target)] + if len(result) == 0: + raise ValueError, 'None of %s start with %s' % (str(self), target) + elif len(result) > 1: + raise ValueError, 'Too many matches: %s' % str(result) + return result[0] + class sqlpyPlus(sqlpython.sqlpython): defaultExtension = 'sql' abbrev = True @@ -351,7 +361,8 @@ def __init__(self): sqlpython.sqlpython.__init__(self) self.binds = CaselessDict() - self.settable += '''autobind bloblimit colors commit_on_exit maxfetch maxtselctrows + self.settable += '''autobind bloblimit colors commit_on_exit + default_rdbms maxfetch maxtselctrows rows_remembered scan serveroutput sql_echo timeout heading wildsql version'''.split() self.settable.remove('case_insensitive') @@ -369,6 +380,8 @@ self.result_history = [] self.rows_remembered = 10000 self.bloblimit = 5 + self.default_rdbms = 'oracle' + self.rdbms_supported = Abbreviatable_List('oracle postgres mysql'.split()) self.version = 'SQLPython %s' % sqlpython.__version__ self.pystate = {'r': [], 'binds': self.binds, 'substs': self.substvars} @@ -546,7 +559,6 @@ (username, schemas) = self.metadata() segment = completion.whichSegment(line) text = text.upper() - print segment if segment in ('select', 'where', 'having', 'set', 'order by', 'group by'): completions = [c for c in schemas[username].column_names if c.startswith(text)] \ or [c for c in schemas.qual_column_names if c.startswith(text)] @@ -556,12 +568,11 @@ elif segment == 'beginning': completions = [n for n in self.get_names() if n.startswith('do_')] + [ 'insert', 'update', 'delete', 'drop', 'alter', 'begin', 'declare', 'create'] - print completions completions = [c for c in completions if c.startswith(text)] + elif segment: + completions = [t for t in schemas[username].table_names if t.startswith(text)] else: completions = [r for r in completion.reserved if r.startswith(text)] - - return completions columnlistPattern = pyparsing.SkipTo(pyparsing.CaselessKeyword('from'))('columns') + \ @@ -1522,10 +1533,12 @@ make_option('-c', '--check', action='store_true', help="Don't refresh, just check refresh status")]) def do_refresh(self, arg, opts): '''Refreshes metadata for the specified schema; only required - if table structures, etc. have changed. ''' + if table structures, etc. have changed. (sqlpython will check + for new objects, and will not waste labor if no objects have + been created or modified in a schema.)''' (username, schemas) = self.metadata() if opts.check: - print schemas.complete + self.poutput(schemas.refresh_times(arg)) return if opts.all: if opts.immediate:
--- a/sqlpython/sqlpython.py Sun Oct 11 10:03:18 2009 -0400 +++ b/sqlpython/sqlpython.py Wed Oct 14 13:57:29 2009 -0400 @@ -107,53 +107,85 @@ s = conn['schemas'] s.refresh_asynch() return conn - def ora_connect(self, arg): - modeval = 0 - oraserv = None - for modere, modevalue in self.connection_modes.items(): - if modere.search(arg): - arg = modere.sub('', arg) - modeval = modevalue - try: - orauser, oraserv = arg.split('@') - except ValueError: - try: - oraserv = os.environ['ORACLE_SID'] - except KeyError: - self.perror('instance not specified and environment variable ORACLE_SID not set') - return - orauser = arg - sid = oraserv - try: - host, sid = oraserv.split('/') - try: - host, port = host.split(':') - port = int(port) - except ValueError: - port = 1521 - oraserv = cx_Oracle.makedsn(host, port, sid) - except ValueError: - pass - try: - orauser, orapass = orauser.split('/') - except ValueError: - orapass = getpass.getpass('Password: ') - if orauser.upper() == 'SYS' and not modeval: - self.pfeedback('Privilege not specified for SYS, assuming SYSOPER') - modeval = cx_Oracle.SYSOPER - result = self.url_connect('oracle://%s:%s@%s/?mode=%d' % (orauser, orapass, oraserv, modeval)) - result['dbname'] = oraserv - return result - connection_modes = {re.compile(' AS SYSDBA', re.IGNORECASE): cx_Oracle.SYSDBA, - re.compile(' AS SYSOPER', re.IGNORECASE): cx_Oracle.SYSOPER} + legal_sql_word = pyparsing.Word(pyparsing.alphanums + '_$#') + legal_hostname = pyparsing.Word(pyparsing.alphanums + '_-.')('host') + pyparsing.Optional( + ':' + pyparsing.Word(pyparsing.nums)('port')) + oracle_connect_parser = legal_sql_word('username') + ( + pyparsing.Optional('/' + pyparsing.CharsNotIn('@')("password")) + + pyparsing.Optional('@' + pyparsing.Optional(legal_hostname + '/') + + legal_sql_word('db_name')) + + pyparsing.Optional(pyparsing.CaselessKeyword('as') + + (pyparsing.CaselessKeyword('sysoper') ^ + pyparsing.CaselessKeyword('sysdba'))('mode'))) + postgresql_connect_parser = (legal_sql_word('db_name') + + pyparsing.Optional(legal_sql_word('username'))) + + def connect_url(self, arg, opts): + rdbms = opts.rdbms or self.default_rdbms + + mode = 0 + host = None + port = None + + if rdbms == 'oracle': + result = self.oracle_connect_parser.parseString(arg) + if result.mode == 'sysdba': + mode = cx_Oracle.SYSDBA + elif result.mode == 'sysoper': + mode = cx_Oracle.SYSOPER + else: + mode = 0 + elif rdbms == 'postgres': + result = self.postgresql_connect_parser.parseString(arg) + port = opts.port or os.environ.get('PGPORT') or 5432 + host = opts.host or os.environ.get('PGHOST') or 'localhost' + + username = result.username or opts.username + if not username and rdbms == 'postgres': + username = os.environ.get('PGUSER') or os.environ.get('USER') + + db_name = result.db_name or opts.database + if not db_name: + if rdbms == 'oracle': + db_name = os.environ.get('ORACLE_SID') + elif rdbms == 'postgres': + db_name = os.environ.get('PGDATABASE') or username + + password = result.password or getpass.getpass('Password: ') + + if host: + if port: + host = '%s:%s' % (host, port) + db_name = '%s/%s' % (host, db_name) + + url = '%s://%s:%s@%s' % (rdbms, username, password, db_name) + if mode: + url = '%s/?mode=%d' % mode + return url + @cmd2.options([cmd2.make_option('-a', '--add', action='store_true', help='add connection (keep current connection)'), cmd2.make_option('-c', '--close', action='store_true', help='close connection {N} (or current)'), cmd2.make_option('-C', '--closeall', action='store_true', - help='close all connections'),]) + help='close all connections'), + cmd2.make_option('--postgres', help='Connect to a postgreSQL database'), + cmd2.make_option('--oracle', help='Connect to an Oracle database'), + cmd2.make_option('--mysql', help='Connect to a MySQL database'), + cmd2.make_option('-r', '--rdbms', type='string', + help='Type of database to connect to (oracle, postgres, mysql)'), + cmd2.make_option('-H', '--host', type='string', + help='Host to connect to (postgresql only)'), + cmd2.make_option('-p', '--port', type='int', + help='Port to connect to (postgresql only)'), + cmd2.make_option('-d', '--database', type='string', + help='Database name to connect to'), + cmd2.make_option('-U', '--username', type='string', + help='Database user name to connect as') + ]) def do_connect(self, arg, opts): + '''Opens the DB connection''' if opts.closeall: self.closeall() @@ -175,7 +207,10 @@ try: connect_info = self.url_connect(arg) except sqlalchemy.exc.ArgumentError, e: - connect_info = self.ora_connect(arg) + connect_info = self.url_connect(self.connect_url(arg, opts)) + except Exception, e: + self.perror(r'URL connection format: rdbms://username:password@host/database') + return if opts.add or (self.connection_number is None): try: self.connection_number = max(self.connections.keys()) + 1 @@ -277,8 +312,7 @@ terminatorSearchString = '|'.join('\\' + d.split()[0] for d in do_terminators.__doc__.splitlines()) bindScanner = {'oracle': Parser(pyparsing.Literal(':') + pyparsing.Word( pyparsing.alphanums + "_$#" )), - 'postgres': Parser(pyparsing.Literal('%(') + - pyparsing.Word(pyparsing.alphanums + "_$#") + ')s')} + 'postgres': Parser(pyparsing.Literal('%(') + legal_sql_word + ')s')} def findBinds(self, target, givenBindVars = {}): result = givenBindVars if self.rdbms in self.bindScanner: