changeset 400:8903d24575f0

desc works on code
author catherine@DellZilla
date Thu, 08 Oct 2009 15:08:44 -0400
parents 7fcb0ddc75a6
children 0cc91493a1d4
files sqlpython/sqlpyPlus.py sqlpython/sqlpython.py
diffstat 2 files changed, 61 insertions(+), 28 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/sqlpyPlus.py	Thu Oct 08 11:13:41 2009 -0400
+++ b/sqlpython/sqlpyPlus.py	Thu Oct 08 15:08:44 2009 -0400
@@ -25,7 +25,7 @@
 """
 import sys, os, re, sqlpython, cx_Oracle, pyparsing, re, completion
 import datetime, pickle, binascii, subprocess, time, itertools, hashlib
-import traceback
+import traceback, operator
 from cmd2 import Cmd, make_option, options, Statekeeper, Cmd2TestCase
 from output_templates import output_templates
 from schemagroup import MetaData
@@ -497,11 +497,14 @@
         except AttributeError:
             return str(datum)
               
-    def tabular_output(self, outformat, rowlimit):
-        try:
-            self.tblname = self.tableNameFinder.search(self.querytext).group(1)
-        except AttributeError:
-            self.tblname = ''
+    def tabular_output(self, outformat, rowlimit, tblname=None):
+        if tblname:
+            self.tblname = tblname
+        else:
+            try:
+                self.tblname = self.tableNameFinder.search(self.querytext).group(1)
+            except AttributeError:
+                self.tblname = ''
         if outformat in output_templates:
             self.colnamelen = max(len(colname) for colname in self.colnames)
             result = output_templates[outformat].generate(formattedForSql=self.formattedForSql, **self.__dict__)        
@@ -524,7 +527,7 @@
             plot.draw()
             return ''
         else:
-            result = self.pmatrix(self.rows, self.curs.description, 
+            result = self.pmatrix(self.rows, 
                                   self.maxfetch, heading=self.heading, 
                                   restructuredtext = (outformat == '\\r'))
         return result
@@ -991,27 +994,61 @@
                                 name_printed = True
                             self.poutput('%d: %s' % (line_num, line))
             
+    def _col_type_descriptor(self, col):
+        if 'precision' in col:
+            return '%s(%d,%d)' % (col['type'], col['length'], col['precision'])
+        elif 'length' in col:
+            return '%s(%d)' % (col['type'], col['length'])
+        else:
+            return col['type']
+        
+    def _key_columns(self, tbl, type):
+        columns = [c['columns'] for c in tbl.constraints.values() if c['type'] == type]
+        if columns:
+            return reduce(list.extend, columns)
+        
     @options([all_users_option,
-              make_option('-l', '--long', action='store_true', help='include column #, comments')])
+              make_option('-l', '--long', action='store_true', help='include column #, comments'),
+              make_option('-A', '--alpha', action='store_true', help='List columns alphabetically')])
     def do_describe(self, arg, opts):
         opts.exact = True
+        if opts.alpha:
+            sortkey = operator.itemgetter('name')
+        else:
+            sortkey = operator.itemgetter('sequence')
         for m in self._matching_database_objects(arg, opts):
             self.tblname = m.descriptor(qualified=opts.get('all'))
             self.pfeedback(self.tblname)
-            if hasattr(m.db_object, 'columns') and not isinstance(m.db_object.columns, tuple): # drop once gerald returns column dicts for views
-                cols = m.db_object.columns.values()                
-                cols.sort() # on column order... or alphabetical with an option
+            if hasattr(m.db_object, 'columns') and not isinstance(m.db_object.columns, list): # drop once gerald returns column dicts for views
+                cols = sorted(m.db_object.columns.values(), key=sortkey)
                 if opts.long:
-                    self.colnames = 'N Name Null? Type Default Comments'.split()
-                    self.rows = [(col['sequence'], col['name'], col['nullable'], 
-                                  col['type'], col.get(default), col.get(comment)) 
+                    primary_key_columns = self._key_columns(m.db_object, 'Primary')
+                    unique_key_columns = self._key_columns(m.db_object, 'Unique')
+                    self.colnames = 'N Name Null? Type Key Default Comments'.split()
+                    self.rows = [(col['sequence'], col['name'], (col['nullable'] and 'NULL') or 'NOT NULL',
+                                  self._col_type_descriptor(col), 
+                                  ((col['name'] in primary_key_columns) and 'P') or
+                                  ((col['name'] in unique_key_columns) and 'U') or '',
+                                  col.get('default'), col.get('comment')) 
                                  for col in cols]
                 else:
-                    self.colnames = 'Name Null? Type'.split()
-                    self.rows = [(col['name'], col['nullable'], col['type']) 
+                    self.colnames = 'Name Nullable Type'.split()
+                    self.rows = [(col['name'], (col['nullable'] and 'NULL') or 'NOT NULL', self._col_type_descriptor(col)) 
                                  for col in cols]
                 self.coltypes = [str] * len(self.colnames)
-                self.tabular_output(arg.parsed.terminator, self.rowlimit(arg))
+                self.poutput(self.tabular_output(arg.parsed.terminator, self.rowlimit(arg), self.tblname) + '\n\n')
+            elif hasattr(m.db_object, 'increment_by'):
+                self.colnames = 'name min_value max_value increment_by'.split()
+                self.coltypes = [str, int, int, int]
+                self.rows = [(getattr(m.db_object, p) for p in self.colnames)]
+                self.poutput(self.tabular_output(arg.parsed.terminator, self.rowlimit(arg), self.tblname) + '\n\n')
+            elif hasattr(m.db_object, 'source'):
+                end_heading = re.compile(r'\bDECLARE|BEGIN\b', re.IGNORECASE)
+                for (index, (ln, line)) in enumerate(m.db_object.source):
+                    if end_heading.search(line):
+                        break
+                self.poutput(''.join(l for (ln, l) in m.db_object.source[:index]))
+                        
             
     def do_deps(self, arg):
         '''Lists all objects that are dependent upon the object.'''
@@ -1520,7 +1557,6 @@
                     sql = self.parsed(sql, 
                                           terminator=arg.parsed.terminator or ';',
                                           suffix=arg.parsed.suffix)
-                    import pdb; pdb.set_trace()
                     self.do_select(sql)
                 elif hasattr(m.db_object, 'source'):
                     for (line_num, line) in m.db_object.source:
--- a/sqlpython/sqlpython.py	Thu Oct 08 11:13:41 2009 -0400
+++ b/sqlpython/sqlpython.py	Thu Oct 08 15:08:44 2009 -0400
@@ -331,23 +331,19 @@
                     color = 'cyan'
             return self.colorcodes[color][True] + val + self.colorcodes[color][False]        
         return val
-    def pmatrix(self,rows,desc,maxlen=30,heading=True,restructuredtext=False):
+    def pmatrix(self,rows,maxlen=30,heading=True,restructuredtext=False):
         '''prints a matrix, used by sqlpython to print queries' result sets'''
-        names = []
-        maxen = []
+        names = self.colnames
+        maxen = [len(n) for n in self.colnames]
         toprint = []
-        for d in desc:
-            n = d[0]
-            names.append(n)      # list col names
-            maxen.append(len(n)) # col length
-        rcols = range(len(desc))
+        rcols = range(len(self.colnames))
         rrows = range(len(rows))
         for i in rrows:          # loops for all rows
             rowsi = map(str, rows[i]) # current row to process
             split = []                # service var is row split is needed
             mustsplit = 0             # flag 
             for j in rcols:
-                if str(desc[j][1]) == "<type 'cx_Oracle.BINARY'>":  # handles RAW columns
+                if str(self.coltypes[j]) == "<type 'cx_Oracle.BINARY'>":  # handles RAW columns
                     rowsi[j] = binascii.b2a_hex(rowsi[j])
                 maxen[j] = max(maxen[j], len(rowsi[j]))    # computes max field length
                 if maxen[j] <= maxlen:
@@ -369,7 +365,8 @@
             rrows2 = range(len(toprint))
             for j in rrows2:
                 val = toprint[j][i]
-                if str(desc[i][1]) == "<type 'cx_Oracle.NUMBER'>":  # right align numbers
+                #import pdb; pdb.set_trace()
+                if str(self.coltypes[i]) == "<type 'cx_Oracle.NUMBER'>":  # right align numbers - but must generalize!
                     toprint[j][i] = (" " * (maxcol-len(val))) + val
                 else:
                     toprint[j][i] = val + (" " * (maxcol-len(val)))