changeset 409:5b88ce5f31ff

ugh, trying to separate -- comments from --flags
author catherine@DellZilla
date Thu, 15 Oct 2009 17:39:54 -0400
parents 80413ef3699a
children 22c96b70ee41
files sqlpython/mysqlpy.py sqlpython/sqlpyPlus.py sqlpython/sqlpython.py
diffstat 3 files changed, 42 insertions(+), 39 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/mysqlpy.py	Thu Oct 15 14:30:32 2009 -0400
+++ b/sqlpython/mysqlpy.py	Thu Oct 15 17:39:54 2009 -0400
@@ -185,37 +185,37 @@
 def run():
     my=mysqlpy()
     print my.__doc__
-    # Arguments to sqlpython: <connection info> (optional) 
-    try:
-        if sys.argv[1][0] != '@':
-            connectstring = sys.argv.pop(1)
-            if len(sys.argv) >= 3 and sys.argv[1].lower() == 'as': # attach AS SYSDBA or AS SYSOPER if present
-                for i in (1,2):
-                    connectstring += ' ' + sys.argv.pop(1)
-            my.do_connect(connectstring)
-        for arg in sys.argv[1:]:
-            if my.onecmd(arg + '\n') == my._STOP_AND_EXIT:
-                return
-    except IndexError:
-        pass
+    # split a complex argument string, like 
+    # ``--postgres -H localhost dbname username ls "select * from tbl" @myscript``
+    # into a portion to feed to ``do_connect`` and a portion to run as SQL commands
+    connectstring = sys.argv[1:]
+    commands = []
+    for (n, arg) in enumerate(sys.argv[1:]):
+        if arg.startswith('@') or len(arg.split()) > 1:
+            connectstring = sys.argv[1:n+1]
+            commands = sys.argv[n+1:]
+            break
+    if connectstring:
+        my.onecmd('connect %s' % ' '.join(connectstring))
+    for command in commands:
+        if my.onecmd(command + '\n') == my._STOP_AND_EXIT:
+            return
     my.cmdloop()
     
 class TestCase(Cmd2TestCase):
     CmdApp = mysqlpy
 
-if __name__ == '__main__':
-    parser = optparse.OptionParser()
-    parser.add_option('-t', '--test', dest='unittests', action='store_true', default=False, help='Run unit test suite')
-    try:
-        (callopts, callargs) = parser.parse_args()
-        if callopts.unittests:
-            mysqlpy.testfiles = callargs
-            sys.argv = [sys.argv[0]]  # the --test argument upsets unittest.main()
-            unittest.main()
-    except optparse.BadOptionError:
-        pass        
+if __name__ == '__main__':    
+    testfiles = sys.argv[1:]
+    for arg in ('-t', '--test'):
+        if arg in testfiles:
+            testfiles.remove(arg)
+            mysqlpy.testfiles = testfiles
+            if not testfiles:
+                print 'No test file specified to run against!'
+                sys.exit()
+    if hasattr(mysqlpy, 'testfiles'):
+        sys.argv = [sys.argv[0]]
+        unittest.main()
     else:
-        #import cProfile, pstats
-        #cProfile.run('run()', 'stats.txt')
-        run()
-        
+        run()
\ No newline at end of file
--- a/sqlpython/sqlpyPlus.py	Thu Oct 15 14:30:32 2009 -0400
+++ b/sqlpython/sqlpyPlus.py	Thu Oct 15 17:39:54 2009 -0400
@@ -348,8 +348,8 @@
     multilineCommands = '''select insert update delete tselect
                       create drop alter _multiline_comment'''.split()
     sqlpython.sqlpython.noSpecialParse.append('spool')
-    commentGrammars = pyparsing.Or([pyparsing.Literal('--') + pyparsing.restOfLine, pyparsing.cStyleComment])
-    commentGrammars = pyparsing.Or([Parser.comment_def, pyparsing.cStyleComment])
+    commentGrammars = pyparsing.cStyleComment
+    multilineOnlyCommentGrammars = pyparsing.Literal('--') + pyparsing.restOfLine  
     prefixParser = pyparsing.Optional(pyparsing.Word(pyparsing.nums)('connection_number') 
                                       + ':')
     reserved_words = [
--- a/sqlpython/sqlpython.py	Thu Oct 15 14:30:32 2009 -0400
+++ b/sqlpython/sqlpython.py	Thu Oct 15 17:39:54 2009 -0400
@@ -13,7 +13,7 @@
 __version__ = '1.6.8'    
 
 class Parser(object):
-    comment_def = "--" + ~ ('-' + pyparsing.CaselessKeyword('begin')) + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n"))    
+    comment_def = "--" + pyparsing.NotAny('-' + pyparsing.CaselessKeyword('begin')) + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n"))    
     def __init__(self, scanner, retainSeparator=True):
         self.scanner = scanner
         self.scanner.ignore(pyparsing.sglQuotedString)
@@ -119,7 +119,7 @@
                                                (pyparsing.CaselessKeyword('sysoper') ^ 
                                                 pyparsing.CaselessKeyword('sysdba'))('mode')))
     postgresql_connect_parser = (legal_sql_word('db_name') + 
-                                 pyparsing.Optional(legal_sql_word('username')))                       
+                                 pyparsing.Optional(legal_sql_word('username')))
           
     def connect_url(self, arg, opts):
         rdbms = opts.rdbms or self.default_rdbms
@@ -170,9 +170,9 @@
                                     help='close connection {N} (or current)'),
                    cmd2.make_option('-C', '--closeall', action='store_true', 
                                     help='close all connections'),
-                   cmd2.make_option('--postgres', help='Connect to a postgreSQL database'),
-                   cmd2.make_option('--oracle', help='Connect to an Oracle database'),
-                   cmd2.make_option('--mysql', help='Connect to a MySQL database'),                   
+                   cmd2.make_option('--postgres', action='store_true', help='Connect to a postgreSQL database'),
+                   cmd2.make_option('--oracle', action='store_true', help='Connect to an Oracle database'),
+                   cmd2.make_option('--mysql', action='store_true', help='Connect to a MySQL database'),                   
                    cmd2.make_option('-r', '--rdbms', type='string', 
                                     help='Type of database to connect to (oracle, postgres, mysql)'),
                    cmd2.make_option('-H', '--host', type='string', 
@@ -189,15 +189,15 @@
         '''Opens the DB connection'''
         if opts.closeall:
             self.closeall()
-            return
+            return 
         if opts.close:
             if not arg:
                 arg = self.connection_number
             self.disconnect(arg)
-            return
+            return 
         if not arg:
             self.list_connections()
-            return
+            return 
         try:
             if self.successful_connection_to_number(arg):
                 return
@@ -207,7 +207,8 @@
         try:
             connect_info = self.url_connect(arg)
         except sqlalchemy.exc.ArgumentError, e:
-            connect_info = self.url_connect(self.connect_url(arg, opts))
+            url = self.connect_url(arg, opts)
+            connect_info = self.url_connect(url)
         except Exception, e:
             self.perror(r'URL connection format: rdbms://username:password@host/database')
             return
@@ -224,6 +225,8 @@
             self.curs.callproc('dbms_output.enable', [])
         if (self.rdbms == 'mysql'):
             self.curs.execute('SET SQL_MODE=ANSI')
+        return 
+    
     def postparsing_precmd(self, statement):
         stop = 0
         self.saved_connection_number = None