changeset 343:d185c87766bd

merged with laptop changes
author Catherine Devlin <catherine.devlin@gmail.com>
date Tue, 14 Apr 2009 16:44:26 -0400
parents 80a1976decf2 (current diff) aa16fe026f01 (diff)
children c2e3223703f2
files
diffstat 4 files changed, 92 insertions(+), 51 deletions(-) [+]
line wrap: on
line diff
--- a/setup.py	Tue Apr 14 16:11:59 2009 -0400
+++ b/setup.py	Tue Apr 14 16:44:26 2009 -0400
@@ -17,7 +17,7 @@
       url="http://packages.python.org/sqlpython",
       packages=find_packages(),
       include_package_data=True,    
-      install_requires=['pyparsing','cmd2>=0.5.2','cx_Oracle','genshi>=0.5','sqlalchemy'],
+      install_requires=['pyparsing','cmd2>=0.5.2','cx_Oracle','genshi>=0.5','sqlalchemy',],
       keywords = 'client oracle database',
       license = 'MIT',
       platforms = ['any'],
--- a/sqlpython/exampleSession.txt	Tue Apr 14 16:11:59 2009 -0400
+++ b/sqlpython/exampleSession.txt	Tue Apr 14 16:44:26 2009 -0400
@@ -7,6 +7,9 @@
 SQL.No_Connection> connect testschema/testschema@orcl
 0:testschema@orcl> drop table play;
 /.*/
+0:testschema@orcl> set colors off
+colors - was: True
+now: False
 0:testschema@orcl> CREATE TABLE play (
 >   title   VARCHAR2(40) CONSTRAINT xpk_play PRIMARY KEY,
 >   author  VARCHAR2(40));
--- a/sqlpython/sqlpyPlus.py	Tue Apr 14 16:11:59 2009 -0400
+++ b/sqlpython/sqlpyPlus.py	Tue Apr 14 16:44:26 2009 -0400
@@ -28,6 +28,7 @@
 from output_templates import output_templates
 from metadata import metaqueries
 from plothandler import Plot
+from sqlpython import Parser
 try:
     import pylab
 except (RuntimeError, ImportError):
@@ -227,42 +228,6 @@
     def pop(self, key, def_val=None):
         return dict.pop(self, key.lower(), def_val)
 
-class Parser(object):
-    comment_def = "--" + ~ ('-' + pyparsing.CaselessKeyword('begin')) + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n"))    
-    def __init__(self, scanner, retainSeparator=True):
-        self.scanner = scanner
-        self.scanner.ignore(pyparsing.sglQuotedString)
-        self.scanner.ignore(pyparsing.dblQuotedString)
-        self.scanner.ignore(self.comment_def)
-        self.scanner.ignore(pyparsing.cStyleComment)
-        self.retainSeparator = retainSeparator
-    def separate(self, txt):
-        itms = []
-        for (sqlcommand, start, end) in self.scanner.scanString(txt):
-            if sqlcommand:
-                if type(sqlcommand[0]) == pyparsing.ParseResults:
-                    if self.retainSeparator:
-                        itms.append("".join(sqlcommand[0]))
-                    else:
-                        itms.append(sqlcommand[0][0])
-                else:
-                    if sqlcommand[0]:
-                        itms.append(sqlcommand[0])
-        return itms
-
-bindScanner = Parser(pyparsing.Literal(':') + pyparsing.Word( pyparsing.alphanums + "_$#" ))   
-    
-def findBinds(target, existingBinds, givenBindVars = {}):
-    result = givenBindVars
-    for finding, startat, endat in bindScanner.scanner.scanString(target):
-        varname = finding[1]
-        try:
-            result[varname] = existingBinds[varname]
-        except KeyError:
-            if not givenBindVars.has_key(varname):
-                print 'Bind variable %s not defined.' % (varname)                
-    return result
-
 class ResultSet(list):
     pass
 
@@ -445,14 +410,7 @@
             return str(datum)
               
     def output(self, outformat, rowlimit):
-        try:
-            self.tblname = self.tableNameFinder.search(self.querytext).group(1)
-        except AttributeError:
-            self.tblname = ''
-        self.colnames = [d[0] for d in self.curs.description]
         if outformat in output_templates:
-            self.colnamelen = max(len(colname) for colname in self.colnames)
-            self.coltypes = [d[1] for d in self.curs.description]
             result = output_templates[outformat].generate(formattedForSql=self.formattedForSql, **self.__dict__)        
         elif outformat == '\\t': # transposed
             rows = [self.colnames]
@@ -671,7 +629,17 @@
                 return
             total_len -= len(rset)
             self.pystate['r'][i] = []
-            
+    
+    def set_query_metadata(self):
+        try:
+            self.tblname = self.tableNameFinder.search(self.querytext).group(1)
+        except AttributeError:
+            self.tblname = ''
+        self.colnames = [d[0] for d in self.curs.description]
+        if outformat in output_templates:
+            self.colnamelen = max(len(colname) for colname in self.colnames)
+            # self.coltypes = [d[1] for d in self.curs.description]   never used?
+                
     def do_select(self, arg, bindVarsIn=None, terminator=None):
         """Fetch rows from a table.
 
@@ -690,7 +658,7 @@
             self.perror("Specify desired number of rows after terminator (not '%s')" % arg.parsed.suffix)
         if arg.parsed.terminator == '\\t':
             rowlimit = rowlimit or self.maxtselctrows
-        self.varsUsed = findBinds(arg, self.binds, bindVarsIn)
+        self.varsUsed = self.findBinds(arg, bindVarsIn)
         if self.wildsql:
             selecttext = self.expandWildSql(arg)
         else:
@@ -713,6 +681,7 @@
                 row.resultset = resultset
             self.pystate['r'].append(resultset)
             self.age_out_resultsets()
+            self.set_query_metadata()
             self.stdout.write('\n%s\n' % (self.output(arg.parsed.terminator, rowlimit)))
         if self.rc == 0:
             self.pfeedback('\nNo rows Selected.\n')
@@ -967,7 +936,31 @@
                     sql = self.parsed(descQueries['PackageObjArgs'][0], terminator=arg.parsed.terminator or ';', suffix=arg.parsed.suffix)
                     self.do_select(sql, bindVarsIn={'package_name':object_name, 'owner':owner, 'object_name':packageObj_name})
 
-
+    def _str_datatype_(self, datatype, length, scale, precision):
+        if precision is not None:
+            result = '%s(%s,%s)' % (datatype, scale, precision)
+        elif length is not None:
+            result = '%s(%s)' % (datatype, length)
+        else:
+            result = datatype
+        return result     
+                            
+    @options([make_option('-l', '--long', action='store_true', help='include column #, comments')])
+    def do_describe(self, arg, opts):
+        schema = self.connections[self.connection_number]['gerald']().schema
+        target = arg.upper().strip()
+        for (objname, obj) in schema.items():
+            if objname.upper() == target:
+                self.poutput(objname)
+                if hasattr(obj, 'columns'):
+                    self.tblname = objname
+                    columns = obj.columns.values()
+                    columns.sort()
+                    self.rows = [c[0], c[1], self._str_datatype_(c[2], c[3], c[4], c[5]), c[6], c[7] for c in columns]
+                    self.colnames = 'position name type nullable default'.split()
+                    self.colnamelen = max(len(colname) for colname in self.colnames)
+                    self.output(arg.parsed.terminator, rowlimit)
+        
     def do_deps(self, arg):
         '''Lists all objects that are dependent upon the object.'''
         target = arg.upper()
@@ -1268,7 +1261,7 @@
         if arg.startswith(':'):
             self.do_setbind(arg[1:])
         else:
-            varsUsed = findBinds(arg, self.binds, {})
+            varsUsed = self.findBinds(arg, {})
             try:
                 self.curs.execute('begin\n%s;end;' % arg, varsUsed)
             except Exception, e:
--- a/sqlpython/sqlpython.py	Tue Apr 14 16:11:59 2009 -0400
+++ b/sqlpython/sqlpython.py	Tue Apr 14 16:44:26 2009 -0400
@@ -8,10 +8,34 @@
 # Best used with the companion modules sqlpyPlus and mysqlpy 
 # See also http://twiki.cern.ch/twiki/bin/view/PSSGroup/SqlPython
 
-import cmd2,getpass,binascii,cx_Oracle,re,os
-import sqlpyPlus, sqlalchemy
+import cmd2,getpass,binascii,cx_Oracle,re,os,functools
+import sqlpyPlus, sqlalchemy, pyparsing, gerald
 __version__ = '1.6.4'    
 
+class Parser(object):
+    comment_def = "--" + ~ ('-' + pyparsing.CaselessKeyword('begin')) + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n"))    
+    def __init__(self, scanner, retainSeparator=True):
+        self.scanner = scanner
+        self.scanner.ignore(pyparsing.sglQuotedString)
+        self.scanner.ignore(pyparsing.dblQuotedString)
+        self.scanner.ignore(self.comment_def)
+        self.scanner.ignore(pyparsing.cStyleComment)
+        self.retainSeparator = retainSeparator
+    def separate(self, txt):
+        itms = []
+        for (sqlcommand, start, end) in self.scanner.scanString(txt):
+            if sqlcommand:
+                if type(sqlcommand[0]) == pyparsing.ParseResults:
+                    if self.retainSeparator:
+                        itms.append("".join(sqlcommand[0]))
+                    else:
+                        itms.append(sqlcommand[0][0])
+                else:
+                    if sqlcommand[0]:
+                        itms.append(sqlcommand[0])
+        return itms
+
+    
 class sqlpython(cmd2.Cmd):
     '''A python module to reproduce Oracle's command line with focus on customization and extention'''
 
@@ -71,13 +95,19 @@
         self.curs = None
         self.no_connection()        
             
+    gerald_classes = {'oracle': (gerald.OracleSchema, 'oracle',),
+                      'postgres': (gerald.PostgresSchema, 'public'),
+                      'mysql': (gerald.MySQLSchema, 'nil')}
     def url_connect(self, arg):
         eng = sqlalchemy.create_engine(arg) 
         self.conn = eng.connect().connection
         conn  = {'conn': self.conn, 'prompt': self.prompt, 'dbname': eng.url.database,
                  'rdbms': eng.url.drivername, 'user': eng.url.username or '', 
                  'eng': eng}
+        gerclass = self.gerald_classes[eng.url.drivername]
+        conn['gerald'] = functools.partial(gerclass[0], gerclass[1], arg.split('/?')[:1][0])
         return conn
+    
     def ora_connect(self, arg):
         modeval = 0
         oraserv = None
@@ -247,8 +277,23 @@
     
     terminatorSearchString = '|'.join('\\' + d.split()[0] for d in do_terminators.__doc__.splitlines())
         
+    bindScanner = {'oracle': Parser(pyparsing.Literal(':') + pyparsing.Word( pyparsing.alphanums + "_$#" )),
+                   'postgres': Parser(pyparsing.Literal('%(') + 
+                                      pyparsing.Word(pyparsing.alphanums + "_$#") + ')s')}
+    def findBinds(self, target, givenBindVars = {}):
+        result = givenBindVars
+        if self.rdbms in self.bindScanner:
+            for finding, startat, endat in self.bindScanner[self.rdbms].scanner.scanString(target):
+                varname = finding[1]
+                try:
+                    result[varname] = self.binds[varname]
+                except KeyError:
+                    if not givenBindVars.has_key(varname):
+                        print 'Bind variable %s not defined.' % (varname)
+        return result
+
     def default(self, arg):
-        self.varsUsed = sqlpyPlus.findBinds(arg, self.binds, givenBindVars={})
+        self.varsUsed = self.findBinds(arg, givenBindVars={})
         ending_args = arg.lower().split()[-2:]
         if 'end' in ending_args:
             command = '%s %s;'