changeset 407:188c86d4a11e

struggling with option parsing
author catherine@DellZilla
date Wed, 14 Oct 2009 13:57:29 -0400
parents 4b481c4293b8
children 80413ef3699a
files sqlpython/completion.py sqlpython/schemagroup.py sqlpython/sqlpyPlus.py sqlpython/sqlpython.py
diffstat 4 files changed, 168 insertions(+), 82 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/completion.py	Sun Oct 11 10:03:18 2009 -0400
+++ b/sqlpython/completion.py	Wed Oct 14 13:57:29 2009 -0400
@@ -1,4 +1,4 @@
-import pyparsing, re
+import pyparsing, re, doctest
 
 sqlStyleComment = pyparsing.Literal("--") + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n"))
 keywords = {'order by': pyparsing.Keyword('order', caseless=True) +
@@ -40,30 +40,31 @@
     results.sort(cmp=lambda x,y:cmp(x[1],y[1]))
     return results
         
-at_beginning = re.compile(r'^\s*\S+\s*$')
+at_beginning = re.compile(r'^\s*\S+$')
 def whichSegment(statement):
-    if at_beginning.search(statement):
+    '''
+    >>> whichSegment("SELECT col FROM t")
+    'from'
+    >>> whichSegment("SELECT * FROM t")
+    'from'
+    >>> whichSegment("DESC ")
+    'DESC'
+    >>> whichSegment("DES")
+    'beginning'
+    >>> whichSegment("")
+    'beginning'
+    >>> whichSegment("select  ")
+    'select'
+    
+    '''
+    if (not statement) or at_beginning.search(statement):
         return 'beginning'
     results = orderedParseResults(keywords.values(), statement)
     if results:
         return ' '.join(results[-1][0])
     else:
-        return None
-
-oracleIdentifierCharacters = pyparsing.alphanums + '_#$'    
-def wordInProgress(statement):
-    result = []
-    letters = list(statement)
-    letters.reverse()
-    for letter in letters:
-        if letter not in oracleIdentifierCharacters:
-            result.reverse()
-            return ''.join(result)
-        result.append(letter)
-    result.reverse()
-    return ''.join(result)
-    
-
+        return statement.split(None,1)[0]
+   
 reserved = '''
       access
      add
@@ -173,4 +174,7 @@
      view
      whenever
      where
-     with '''.split()
\ No newline at end of file
+     with '''.split()
+
+if __name__ == '__main__':
+    doctest.testmod()
--- a/sqlpython/schemagroup.py	Sun Oct 11 10:03:18 2009 -0400
+++ b/sqlpython/schemagroup.py	Wed Oct 14 13:57:29 2009 -0400
@@ -47,11 +47,40 @@
                                [s.qual_table_names for s in self.schemas.values()],
                                [])
         
+class OracleSchemaAccess(object):        
+    child_type = gerald.OracleSchema
+    current_database_time_query = 'SELECT sysdate FROM dual'
+    def latest_ddl_timestamp_query(self, username, connection):
+        curs = connection.cursor()
+        curs.execute('''SELECT   owner, MAX(last_ddl_time)
+                        FROM     all_objects
+                        GROUP BY owner
+                        -- sort :username to top
+                        ORDER BY REPLACE(owner, :username, 'A'), owner''',
+                     {'username': username.upper()})
+        return curs 
+
+class PostgresSchemaAccess(object):        
+    child_type = gerald.PostgresSchema
+    current_database_time_query = 'SELECT current_time'
+    def latest_ddl_timestamp_query(self, username, connection):
+        curs = connection.cursor()
+        curs.execute("""SELECT  '%s', current_time""" % username)
+        return curs 
+    
+class MySQLSchemaAccess(object):        
+    child_type = gerald.MySQLSchema
+    current_database_time_query = 'SELECT sysdate FROM dual'
+    def latest_ddl_timestamp_query(self, username, connection):
+        curs = connection.cursor()
+        curs.execute("""SELECT  '%s', current_time""" % username)
+        return curs 
+    
 class SchemaDict(dict):
-    schema_types = {'oracle': gerald.OracleSchema}
+    schema_types = {'oracle': OracleSchemaAccess, 'postgres': PostgresSchemaAccess, 'mysql': MySQLSchemaAccess}
     def __init__(self, dct, rdbms, user, connection, connection_string):
         dict.__init__(self, dct)
-        self.child_type = self.schema_types[rdbms]
+        self.schema_access = self.schema_types[rdbms]()
         self.user = user
         self.connection = connection
         self.gerald_connection_string = gerald_connection_string(connection_string)
@@ -59,19 +88,24 @@
         self.complete = 0
     def refresh_asynch(self):
         self.refresh_thread.start()
+    current_database_time_sql = {gerald.OracleSchema: 'SELECT sysdate FROM dual',
+                                 gerald.PostgresSchema: 'SELECT current_time'}
     def get_current_database_time(self):
         curs = self.connection.cursor()
-        curs.execute('SELECT sysdate FROM dual')
+        curs.execute(self.schema_access.current_database_time_query)
         return curs.fetchone()[0]              
+    def refresh_times(self, target_schema):
+        now = self.get_current_database_time()
+        result = []
+        for (schema_name, schema) in self.items():
+            if (not target_schema) or (target_schema.lower() == schema_name.lower()):
+                result.append('%s: %s  (%s ago)' % (schema_name, schema.refreshed, now - schema.refreshed))
+        result.sort()
+        return '\n'.join(result)
+            
     def refresh(self):
         current_database_time = self.get_current_database_time()
-        curs = self.connection.cursor()
-        curs.execute('''SELECT   owner, MAX(last_ddl_time)
-                        FROM     all_objects
-                        GROUP BY owner
-                        -- sort :username to top
-                        ORDER BY REPLACE(owner, :username, 'A'), owner''',
-                     {'username': self.user.upper()})
+        curs = self.schema_access.latest_ddl_timestamp_query(self.user, self.connection)
         for (owner, last_ddl_time) in curs.fetchall():
             if (owner not in self) or (self[owner].refreshed < last_ddl_time):
                 self.refresh_one(owner, current_database_time)
@@ -81,11 +115,12 @@
         self.column_names = [s.column_names for s in self.values()]
         self.columns = reduce(operator.add, [s.column_names for s in self.values()])
         self.complete = 'all'
-        print 'metadata discovered'
     def refresh_one(self, owner, current_database_time=None):
+        #owner = owner.upper()
+        owner = str(owner)
         if not current_database_time:
             current_database_time = self.get_current_database_time()
-        self[owner] = self.child_type(owner, self.gerald_connection_string)
+        self[owner] = self.schema_access.child_type(owner, self.gerald_connection_string)
         self[owner].refreshed = current_database_time        
         build_column_list(self[owner])
 
--- a/sqlpython/sqlpyPlus.py	Sun Oct 11 10:03:18 2009 -0400
+++ b/sqlpython/sqlpyPlus.py	Wed Oct 14 13:57:29 2009 -0400
@@ -329,6 +329,16 @@
         else:
             return '(BLOB not saved, check bloblimit)'
         
+class Abbreviatable_List(list):
+    def match(self, target):
+        target = target.lower()
+        result = [i for i in self if i.startswith(target)]
+        if len(result) == 0:
+            raise ValueError, 'None of %s start with %s' % (str(self), target)
+        elif len(result) > 1:
+            raise ValueError, 'Too many matches: %s' % str(result)
+        return result[0]
+    
 class sqlpyPlus(sqlpython.sqlpython):
     defaultExtension = 'sql'
     abbrev = True    
@@ -351,7 +361,8 @@
     def __init__(self):
         sqlpython.sqlpython.__init__(self)
         self.binds = CaselessDict()
-        self.settable += '''autobind bloblimit colors commit_on_exit maxfetch maxtselctrows 
+        self.settable += '''autobind bloblimit colors commit_on_exit 
+                            default_rdbms maxfetch maxtselctrows 
                             rows_remembered scan serveroutput 
                             sql_echo timeout heading wildsql version'''.split()
         self.settable.remove('case_insensitive')
@@ -369,6 +380,8 @@
         self.result_history = []
         self.rows_remembered = 10000
         self.bloblimit = 5
+        self.default_rdbms = 'oracle'
+        self.rdbms_supported = Abbreviatable_List('oracle postgres mysql'.split())
         self.version = 'SQLPython %s' % sqlpython.__version__
         self.pystate = {'r': [], 'binds': self.binds, 'substs': self.substvars}
         
@@ -546,7 +559,6 @@
         (username, schemas) = self.metadata()
         segment = completion.whichSegment(line)
         text = text.upper()
-        print segment
         if segment in ('select', 'where', 'having', 'set', 'order by', 'group by'):
             completions = [c for c in schemas[username].column_names if c.startswith(text)] \
                           or [c for c in schemas.qual_column_names if c.startswith(text)]
@@ -556,12 +568,11 @@
         elif segment == 'beginning':
             completions = [n for n in self.get_names() if n.startswith('do_')] + [
                            'insert', 'update', 'delete', 'drop', 'alter', 'begin', 'declare', 'create']
-            print completions
             completions = [c for c in completions if c.startswith(text)]     
+        elif segment:
+            completions = [t for t in schemas[username].table_names if t.startswith(text)]
         else:
             completions = [r for r in completion.reserved if r.startswith(text)]
-                            
-                           
         return completions
     
     columnlistPattern = pyparsing.SkipTo(pyparsing.CaselessKeyword('from'))('columns') + \
@@ -1522,10 +1533,12 @@
               make_option('-c', '--check', action='store_true', help="Don't refresh, just check refresh status")])
     def do_refresh(self, arg, opts):
         '''Refreshes metadata for the specified schema; only required
-           if table structures, etc. have changed. '''
+           if table structures, etc. have changed. (sqlpython will check
+           for new objects, and will not waste labor if no objects have
+           been created or modified in a schema.)'''
         (username, schemas) = self.metadata()
         if opts.check:
-            print schemas.complete
+            self.poutput(schemas.refresh_times(arg))
             return
         if opts.all:
             if opts.immediate:
--- a/sqlpython/sqlpython.py	Sun Oct 11 10:03:18 2009 -0400
+++ b/sqlpython/sqlpython.py	Wed Oct 14 13:57:29 2009 -0400
@@ -107,53 +107,85 @@
         s = conn['schemas']
         s.refresh_asynch()
         return conn
-    def ora_connect(self, arg):
-        modeval = 0
-        oraserv = None
-        for modere, modevalue in self.connection_modes.items():
-            if modere.search(arg):
-                arg = modere.sub('', arg)
-                modeval = modevalue
-        try:
-            orauser, oraserv = arg.split('@')
-        except ValueError:
-            try:
-                oraserv = os.environ['ORACLE_SID']
-            except KeyError:
-                self.perror('instance not specified and environment variable ORACLE_SID not set')
-                return
-            orauser = arg
-        sid = oraserv
-        try:
-            host, sid = oraserv.split('/')
-            try:
-                host, port = host.split(':')
-                port = int(port)
-            except ValueError:
-                port = 1521
-            oraserv = cx_Oracle.makedsn(host, port, sid)
-        except ValueError:
-            pass
-        try:
-            orauser, orapass = orauser.split('/')
-        except ValueError:
-            orapass = getpass.getpass('Password: ')
-        if orauser.upper() == 'SYS' and not modeval:
-            self.pfeedback('Privilege not specified for SYS, assuming SYSOPER')
-            modeval = cx_Oracle.SYSOPER
-        result = self.url_connect('oracle://%s:%s@%s/?mode=%d' % (orauser, orapass, oraserv, modeval))
-        result['dbname'] = oraserv
-        return result
     
-    connection_modes = {re.compile(' AS SYSDBA', re.IGNORECASE): cx_Oracle.SYSDBA, 
-                        re.compile(' AS SYSOPER', re.IGNORECASE): cx_Oracle.SYSOPER}
+    legal_sql_word = pyparsing.Word(pyparsing.alphanums + '_$#')
+    legal_hostname = pyparsing.Word(pyparsing.alphanums + '_-.')('host') + pyparsing.Optional(
+        ':' + pyparsing.Word(pyparsing.nums)('port'))
+    oracle_connect_parser = legal_sql_word('username') + (
+                            pyparsing.Optional('/' + pyparsing.CharsNotIn('@')("password")) + 
+                            pyparsing.Optional('@' + pyparsing.Optional(legal_hostname + '/') +
+                                               legal_sql_word('db_name')) + 
+                            pyparsing.Optional(pyparsing.CaselessKeyword('as') + 
+                                               (pyparsing.CaselessKeyword('sysoper') ^ 
+                                                pyparsing.CaselessKeyword('sysdba'))('mode')))
+    postgresql_connect_parser = (legal_sql_word('db_name') + 
+                                 pyparsing.Optional(legal_sql_word('username')))                       
+          
+    def connect_url(self, arg, opts):
+        rdbms = opts.rdbms or self.default_rdbms
+        
+        mode = 0
+        host = None
+        port = None
+        
+        if rdbms == 'oracle':
+            result = self.oracle_connect_parser.parseString(arg)
+            if result.mode == 'sysdba':
+                mode = cx_Oracle.SYSDBA
+            elif result.mode == 'sysoper':
+                mode = cx_Oracle.SYSOPER   
+            else:
+                mode = 0
+        elif rdbms == 'postgres':
+            result = self.postgresql_connect_parser.parseString(arg)
+            port = opts.port or os.environ.get('PGPORT') or 5432            
+            host = opts.host or os.environ.get('PGHOST') or 'localhost'
+       
+        username = result.username or opts.username           
+        if not username and rdbms == 'postgres':
+            username = os.environ.get('PGUSER') or os.environ.get('USER')
+
+        db_name = result.db_name or opts.database
+        if not db_name:
+            if rdbms == 'oracle':
+                db_name = os.environ.get('ORACLE_SID')
+            elif rdbms == 'postgres':
+                db_name = os.environ.get('PGDATABASE') or username
+        
+        password = result.password or getpass.getpass('Password: ')
+               
+        if host:
+            if port:
+                host = '%s:%s' % (host, port)
+            db_name = '%s/%s' % (host, db_name)
+
+        url = '%s://%s:%s@%s' % (rdbms, username, password, db_name)
+        if mode:
+            url = '%s/?mode=%d' % mode
+        return url
+    
     @cmd2.options([cmd2.make_option('-a', '--add', action='store_true', 
                                     help='add connection (keep current connection)'),
                    cmd2.make_option('-c', '--close', action='store_true', 
                                     help='close connection {N} (or current)'),
                    cmd2.make_option('-C', '--closeall', action='store_true', 
-                                    help='close all connections'),])
+                                    help='close all connections'),
+                   cmd2.make_option('--postgres', help='Connect to a postgreSQL database'),
+                   cmd2.make_option('--oracle', help='Connect to an Oracle database'),
+                   cmd2.make_option('--mysql', help='Connect to a MySQL database'),                   
+                   cmd2.make_option('-r', '--rdbms', type='string', 
+                                    help='Type of database to connect to (oracle, postgres, mysql)'),
+                   cmd2.make_option('-H', '--host', type='string', 
+                                    help='Host to connect to (postgresql only)'),                                  
+                   cmd2.make_option('-p', '--port', type='int', 
+                                    help='Port to connect to (postgresql only)'),                                  
+                   cmd2.make_option('-d', '--database', type='string', 
+                                    help='Database name to connect to'),
+                   cmd2.make_option('-U', '--username', type='string', 
+                                    help='Database user name to connect as')
+                   ])
     def do_connect(self, arg, opts):
+ 
         '''Opens the DB connection'''
         if opts.closeall:
             self.closeall()
@@ -175,7 +207,10 @@
         try:
             connect_info = self.url_connect(arg)
         except sqlalchemy.exc.ArgumentError, e:
-            connect_info = self.ora_connect(arg)
+            connect_info = self.url_connect(self.connect_url(arg, opts))
+        except Exception, e:
+            self.perror(r'URL connection format: rdbms://username:password@host/database')
+            return
         if opts.add or (self.connection_number is None):
             try:
                 self.connection_number = max(self.connections.keys()) + 1
@@ -277,8 +312,7 @@
     terminatorSearchString = '|'.join('\\' + d.split()[0] for d in do_terminators.__doc__.splitlines())
         
     bindScanner = {'oracle': Parser(pyparsing.Literal(':') + pyparsing.Word( pyparsing.alphanums + "_$#" )),
-                   'postgres': Parser(pyparsing.Literal('%(') + 
-                                      pyparsing.Word(pyparsing.alphanums + "_$#") + ')s')}
+                   'postgres': Parser(pyparsing.Literal('%(') + legal_sql_word + ')s')}
     def findBinds(self, target, givenBindVars = {}):
         result = givenBindVars
         if self.rdbms in self.bindScanner: