changeset 515:86d5408e596b

trying to merge with rev 507
author catherine.devlin@gmail.com
date Tue, 02 Nov 2010 04:32:19 -0400
parents ea547649a1b8
children 74255a272f62
files sqlpython/connections.py sqlpython/sqlpyPlus.py
diffstat 2 files changed, 86 insertions(+), 86 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/connections.py	Tue Nov 02 04:14:00 2010 -0400
+++ b/sqlpython/connections.py	Tue Nov 02 04:32:19 2010 -0400
@@ -86,6 +86,8 @@
     database = None
     mode = 0
     connection_uri_parser = re.compile('(?P<rdbms>postgres|oracle|mysql|sqlite|mssql)://?(?P<connect_string>.*$)', re.IGNORECASE)    
+    oracle_style_connection_parser = re.compile('(?P<username>[^/\s@]*)(/(?P<password>[^/\s@]*))?(@((?P<hostname>[^/\s:]*)(:(?P<port>\d{1,4}))?/)?(?P<database>[^/\s:]*))?(\s+as\s+(?P<mode>sys(dba|oper)))?',
+                                     re.IGNORECASE)
     connection_parser = re.compile('((?P<database>\S+)(\s+(?P<username>\S+))?)?')    
     def __init__(self, arg, opts, default_rdbms = 'oracle'):
         'no docstring'
@@ -117,8 +119,10 @@
         self.default_rdbms = default_rdbms
         self.determine_rdbms()  # may be altered later as connect string is parsed
         if not self.parse_connect_uri(arg):
-            self.set_defaults()        
+            self.set_defaults()       
             connectargs = self.connection_parser.search(self.arg)
+            if '@' in connectargs.group('database'):
+                connectargs = OracleInstance.connection_parser.search(self.arg)
             if connectargs:
                 for param in ('username', 'password', 'database', 'port', 'hostname', 'mode'):
                     if hasattr(opts, param) and getattr(opts, param):
@@ -176,36 +180,9 @@
                        pyparsing.Optional(sqlname("owner") + ".") +
                        pyparsing.Optional(sqlname("name")) +
                        pyparsing.stringEnd ))
-    def parse_identifier(self, identifier):
-        """
-        >>> opts = OptionTestDummy(postgres=True, password='password')        
-        >>> db = DatabaseInstance('thedatabase theuser', opts)
-        >>> result = db.parse_identifier('scott.pets')
-        >>> (result.owner, result.type, result.name)
-        ('scott', '%', 'pets')
-        >>> result = db.parse_identifier('pets')
-        >>> (result.owner, result.type, result.name)
-        ('%', '%', 'pets')
-        >>> result = db.parse_identifier('pe*')
-        >>> (result.owner, result.type, result.name)
-        ('%', '%', 'pe%')
-        >>> result = db.parse_identifier('scott/table/pets')
-        >>> (result.owner, result.type, result.name)
-        ('scott', 'table', 'pets')
-        >>> result = db.parse_identifier('table/scott.pets')
-        >>> (result.owner, result.type, result.name)
-        ('scott', 'table', 'pets')
-        >>> result = db.parse_identifier('')
-        >>> (result.owner, result.type, result.name)
-        ('%', '%', '%')
-        >>> result = db.parse_identifier('table/scott.*')
-        >>> (str(result.owner), str(result.type), str(result.name))
-        ('scott', 'table', '%')
-        """
-        identifier = self.sql_format_wildcards(identifier)
-        result = {'owner': '%', 'type': '%', 'name': '%'}
-        result.update(dict(self.ls_parser.parseString(identifier)))
-        return result 
+    identifier_regex = re.compile(
+                       r'((?P<object_type>DATABASE LINK|DIRECTORY|FUNCTION|INDEX|JOB|MATERIALIZED VIEW|PACKAGE|PROCEDURE|SEQUENCE|SYNONYM|TABLE|TRIGGER|TYPE|VIEW|BASE TABLE)($|[\\/.\s])+)?(?P<remainder>.*)',
+                       re.IGNORECASE)
     def comparison_operator(self, target):
         if ('%' in target) or ('_' in target):
             operator = 'LIKE'
@@ -265,8 +242,10 @@
     gerald_types = {'TABLE': gerald.oracle_schema.Table,
                     'VIEW': gerald.oracle_schema.View}
     def object_metadata(self, owner, object_type, name):
-        return self.gerald_types[object_type](name, self.connection.cursor(), owner)
-                      
+        if object_type in self.gerald_types:
+            return self.gerald_types[object_type](name, self.connection.cursor(), owner)
+        else:
+            raise NotImplementedError, '%s not implemented for this RDBMS' % object_type
 
 parser = optparse.OptionParser()
 parser.add_option('--postgres', action='store_true', help='Connect to postgreSQL: `connect --postgres [DBNAME [USERNAME]]`')
@@ -423,11 +402,15 @@
         'Puts a tuple of (name, value) pairs into the bind format desired by cx_Oracle'
         return dict((b[0], b[1].upper()) for b in binds)
     gerald_types = {'TABLE': gerald.oracle_schema.Table,
-                    'VIEW': gerald.oracle_schema.View}
+                    'VIEW': gerald.oracle_schema.View,
+                    'TRIGGER': gerald.oracle_schema.Trigger,
+                    'SEQUENCE': gerald.oracle_schema.Sequence,
+                    'PACKAGE': lambda name, cursor, owner: gerald.oracle_schema.Package(name, 'PACKAGE', cursor, owner),
+                    'DATABASE LINK': gerald.oracle_schema.DatabaseLink,
+                    'FUNCTION': lambda name, cursor, owner: gerald.oracle_schema.CodeObject(name, 'FUNCTION', cursor, owner),
+                    'PROCEDURE': lambda name, cursor, owner: gerald.oracle_schema.CodeObject(name, 'PROCEDURE', cursor, owner),
+                    }
 
                 
 if __name__ == '__main__':
-    opts = OptionTestDummy(password='password')
-    db = DatabaseInstance('oracle://system:twttatl@orcl', opts)
-    print list(db.findAll(''))
-    #doctest.testmod()
+    doctest.testmod()
--- a/sqlpython/sqlpyPlus.py	Tue Nov 02 04:14:00 2010 -0400
+++ b/sqlpython/sqlpyPlus.py	Tue Nov 02 04:32:19 2010 -0400
@@ -419,7 +419,7 @@
         self.default_rdbms = 'oracle'
         self.rdbms_supported = Abbreviatable_List('oracle postgres mysql'.split())
         self.version = 'SQLPython %s' % sqlpython.__version__
-        self.pystate = {'r': [], 'binds': self.binds, 'substs': self.substvars}
+        self.pystate = {'r': [], 'binds': self.binds, 'substs': self.substvars, 'sql': self.onecmd_plus_hooks}
         
     # overrides cmd's parseline
     def parseline(self, line):
@@ -592,13 +592,11 @@
         segment = completion.whichSegment(line)        
         text = text.upper()
         if segment in ('select', 'where', 'having', 'set', 'order by', 'group by'):
-            completions = [c for c in schemas[username].column_names if c.startswith(text)] \
-                          or [c for c in schemas.qual_column_names if c.startswith(text)]
-                          # TODO: the latter not working
+            completions = [c[3] for c in self.current_instance.columns(text + '%', None)]
         elif segment in ('from', 'update', 'insert into'):
             # print schemas[username].table_names
             # TODO: from postgres, these table names are jrrt.fishies, etc.
-            completions = [t for t in schemas[username].table_names if t.startswith(text)]
+            completions = [object_name for (owner, object_type, object_name, synonym_name) in self.current_instance.objects(text + '%', None)]
         elif segment == 'beginning':
             completions = [n for n in self.get_names() if n.startswith('do_')] + [
                            'insert', 'update', 'delete', 'drop', 'alter', 'begin', 'declare', 'create']
@@ -731,6 +729,14 @@
         if self.scan:
             raw = self.ampersand_substitution(raw, regexpr=self.singleampre, isglobal=False)
         return raw
+    def postparse(self, parseResult):
+        if (not parseResult.command):
+            try:
+                connection_number = int(parseResult.instance_number)
+                parseResult = self.parser.parseString('connect %d' % connection_number)
+            except (TypeError, ValueError):
+                return parseResult
+        return parseResult
     
     rowlimitPattern = pyparsing.Word(pyparsing.nums)('rowlimit')
     terminators = '; \\C \\t \\i \\p \\l \\L \\b \\r'.split() + output_templates.keys()
@@ -856,33 +862,37 @@
         self._pull(arg, opts)
 
     def _pull(self, arg, opts, vc=None):  
+        opts.all = True
         statekeeper = Statekeeper(opts.dump and self, ('stdout',))
         try:
-            for (owner, object_type, name) in self.current_instance.objects(arg, opts):
+            for (owner, object_type, name, synonym_name) in self.current_instance.objects(arg, opts):
                 obj = self.current_instance.object_metadata(owner, object_type, name)
-                txt = obj.get_ddl()
-                if opts.get('lines'):
-                    txt = self._with_line_numbers(txt)    
-                if opts.dump:
-                    path = os.path.join(owner.lower(), object_type.lower()).replace(' ', '_')
-                    try:
-                        os.makedirs(path)
-                    except OSError:
-                        pass
-                    filename = os.path.join(path, '%s.sql' % name.lower())
-                    self.stdout = open(filename, 'w')
-                if opts.get('num') is not None:
-                    txt = txt.splitlines()
-                    txt = centeredSlice(txt, center=opts.num+1, width=opts.width)
-                    txt = '\n'.join(txt)
-                else:
-                    txt = 'REMARK BEGIN %s/%s/%s\n%s\nREMARK END\n' % (owner, object_type, name, txt)
-                self.poutput(txt)
-                if opts.dump:
-                    self.stdout.close()
-                    statekeeper.restore()
-                    if vc:
-                        subprocess.call(vc + [filename])                    
+                txts = [(object_type, obj.get_ddl())]
+                if hasattr(obj, 'get_body_ddl'):
+                    txts.append(('PACKAGE BODY', obj.get_body_ddl()))
+                for (object_type, txt) in txts:
+                    if opts.get('lines'):
+                        txt = self._with_line_numbers(txt)    
+                    if opts.dump:
+                        path = os.path.join(owner.lower(), object_type.lower()).replace(' ', '_')
+                        try:
+                            os.makedirs(path)
+                        except OSError:
+                            pass
+                        filename = os.path.join(path, '%s.sql' % name.lower())
+                        self.stdout = open(filename, 'w')
+                    if opts.get('num') is not None:
+                        txt = txt.splitlines()
+                        txt = centeredSlice(txt, center=opts.num+1, width=opts.width)
+                        txt = '\n'.join(txt)
+                    else:
+                        txt = 'REMARK BEGIN\n%s\nREMARK END\n/\n\n' % txt
+                    self.poutput(txt)
+                    if opts.dump:
+                        self.stdout.close()
+                        statekeeper.restore()
+                        if vc:
+                            subprocess.call(vc + [filename])                    
         except:
             statekeeper.restore()
             raise
@@ -1012,10 +1022,10 @@
         """Finds argument in source code or (with -c) in column definitions."""
         if opts.col:
             for (owner, object_type, table_name, column_name) in self.current_instance.columns(arg, opts):
-                self.poutput('%s %s.%s.%s' % (object_type, owner, table_name, column_name))
+                self.poutput('%s.%s' % (self.object_label(object_type, owner, table_name, None), column_name))
         else:
             for (owner, object_type, name, line_number, txt) in self.current_instance.source(arg, opts):
-                self.poutput('%s %s.%s %d: %s' % (object_type, owner, name, line_number, txt))
+                self.poutput('%s %d: %s' % (self.object_label(object_type, owner, name, None), line_number, txt))
            
     def _col_type_descriptor(self, col):
         #if col['type'] in ('integer',):
@@ -1040,18 +1050,22 @@
               all_users_option,
               make_option('-l', '--long', action='store_true', help='long descriptions'),
               make_option('-r', '--reverse', action='store_true', help="Reverse order while sorting")]
-        
-    @options(standard_options + [
-              make_option('-A', '--alpha', action='store_true', help='List columns alphabetically')])
+       
+    def object_label(self, object_type, owner, name, synonym_name):
+        return '%s %s.%s%s' % (object_type, owner, name, synonym_name and (synonym_name != name) and ' ("%s")' % synonym_name or '')
+    @options([make_option('-l', '--long', action='store_true', help='long descriptions'),
+              make_option('-r', '--reverse', action='store_true', help="Reverse order while sorting"),
+              make_option('-a', '--alpha', action='store_true', help='List columns alphabetically')])
     def do_describe(self, arg, opts):
+        opts.all = True
         rowlimit = self.rowlimit(arg)
         if opts.alpha:
             sortkey = operator.itemgetter('name')
         else:
             sortkey = operator.itemgetter('sequence')
-        for (owner, object_type, name) in self.current_instance.objects(arg, opts):
+        for (owner, object_type, name, synonym_name) in self.current_instance.objects(arg, opts):
             obj = self.current_instance.object_metadata(owner, object_type, name)
-            self.tblname = '%s %s.%s' % (object_type, owner, name)
+            self.tblname = self.object_label(object_type, owner, name, synonym_name)
             self.pfeedback(self.tblname)
             if opts.long and hasattr(obj, 'comments') and obj.comments:
                 self.poutput(obj.comments) 
@@ -1084,24 +1098,27 @@
                     if end_heading.search(line):
                         break
                 self.poutput(''.join(l for (ln, l) in obj.source[:index]))
-    @options([all_users_option])            
+    @options([])            
     def do_deps(self, arg, opts):
+        opts.all = True
         '''Lists indexes, constraints, and triggers depending on an object'''
         #TODO: doesn't account for views; don't know about primary keys
-        for (owner, object_type, name) in self.current_instance.objects(arg, opts):
+        for (owner, object_type, name, synonym_name) in self.current_instance.objects(arg, opts):
             obj = self.current_instance.object_metadata(owner, object_type, name)
+            self.poutput(self.object_label(object_type, owner, name, synonym_name))
             for deptype in ('indexes', 'constraints', 'triggers'):
                 if hasattr(obj, deptype):
                     for (depname, depobj) in getattr(obj, deptype).items():
                         self.poutput('%s %s' % (deptype, depname))
                 
-    @options([all_users_option])        
+    @options([])        
     def do_comments(self, arg, opts):
+        opts.all = True
         'Prints comments on a table and its columns.'
-        for (owner, object_type, name) in self.current_instance.objects(arg, opts):
+        for (owner, object_type, name, synonym_name) in self.current_instance.objects(arg, opts):
             obj = self.current_instance.object_metadata(owner, object_type, name)
-            if hasattr(obj, 'comments'):
-                self.poutput('%s %s.%s' % object_type, owner, name)
+            if hasattr(obj, 'comments'):  
+                self.poutput(self.object_label(object_type, owner, name, synonym_name))
                 self.poutput(obj.comments)
                 if hasattr(obj, 'columns'):
                     columns = obj.columns.values()
@@ -1278,10 +1295,10 @@
         
     def do__dir_(self, arg, opts, plural_name, str_function):
         long = opts.get('long')
-        for (owner, object_type, name) in self.current_instance.objects(arg, opts):
+        for (owner, object_type, name, synonym_name) in self.current_instance.objects(arg, opts):
             obj = self.current_instance.object_metadata(owner, object_type, name)
             if hasattr(obj, plural_name):
-                self.pfeedback('%s on %s' % (plural_name.title(), '%s %s.%s' % (object_type, owner, name)))
+                self.pfeedback('%s on %s' % (plural_name.title(), self.object_label(object_type, owner, name, synonym_name)))
                 result = [str_function(depobj, long) for depobj in getattr(obj, plural_name).values()]
                 result.sort(reverse=bool(opts.reverse))
                 self.poutput('\n'.join(result))
@@ -1519,8 +1536,8 @@
 
     def _do_ls(self, arg, opts):
         'Functional core of ``do_ls``, split out into an undecorated version to be callable from other methods'
-        for row in self.current_instance.objects(arg, opts):
-            self.poutput('%s/%s/%s' % row)
+        for (owner, type, name, synonym_name) in self.current_instance.objects(arg, opts):
+            self.poutput(self.object_label(type, owner, name, synonym_name))
                 
     @options(standard_options)
     def do_ls(self, arg, opts):
@@ -1552,9 +1569,9 @@
         re_pattern = re.compile(self._to_re_wildcards(pattern), 
                                 (opts.ignorecase and re.IGNORECASE) or 0)
         for target in targets:
-            for (owner, object_type, name) in self.current_instance.objects(target, opts):
+            for (owner, object_type, name, synonym_name) in self.current_instance.objects(target, opts):
                 obj = self.current_instance.object_metadata(owner, object_type, name)
-                self.pfeedback('%s %s.%s' % (object_type, owner, name))
+                self.pfeedback(self.object_label(object_type, owner, name, synonym_name))
                 if hasattr(obj, 'columns'):
                     clauses = []
                     for col in obj.columns: