changeset 3:cd23cd62de3c

history working pretty well
author devlinjs@FA7CZA6N1254998.wrightpatterson.afmc.ds.af.mil
date Mon, 03 Dec 2007 16:43:43 -0500
parents 59903dcaf327
children 23c3a58d7804
files sqlpyPlus.py sqlpython.py
diffstat 2 files changed, 97 insertions(+), 48 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpyPlus.py	Mon Dec 03 14:29:16 2007 -0500
+++ b/sqlpyPlus.py	Mon Dec 03 16:43:43 2007 -0500
@@ -261,12 +261,51 @@
                 print 'Bind variable %s not defined.' % (varname)                
     return result
 
+class HistoryItem(str):
+    def __init__(self, instr):
+        str.__init__(self, instr)
+        self.lowercase = self.lower()
+        self.idx = None
+    def pr(self):
+        print '-------------------------[%d]' % (self.idx)
+        print self
+        
+class History(list):
+    def append(self, new):
+        new = HistoryItem(new)
+        list.append(self, new)
+        new.idx = len(self)
+    def extend(self, new):
+        for n in new:
+            self.append(n)
+    def get(self, getme):
+        try:
+            idx = int(getme)
+            try:
+                return [self[idx-1]]
+            except IndexError:
+                return []
+        except ValueError:  # search for a string
+            try:
+                getme = getme.strip()
+            except:
+                print "Don't know how to handle %s." % (str(getme))
+                return 
+            if getme.startswith(r'/') and getme.endswith(r'/'):
+                finder = re.compile(getme[1:-1], re.DOTALL | re.MULTILINE | re.IGNORECASE)
+                def isin(hi):
+                    return finder.search(hi)
+            else:
+                def isin(hi):
+                    return (getme.lower() in hi.lowercase)
+            return [itm for itm in self[:-1] if isin(itm)]
+        
 class sqlpyPlus(sqlpython.sqlpython):
     def __init__(self):
         sqlpython.sqlpython.__init__(self)
         self.binds = CaselessDict()
         self.sqlBuffer = []
-        self.history = []
+        self.history = History()
         self.settable = ['maxtselctrows', 'maxfetch', 'autobind', 'failover', 'timeout'] # settables must be lowercase
         self.stdoutBeforeSpool = sys.stdout
         self.spoolFile = None
@@ -328,6 +367,7 @@
             statement = ' '.join(args)      
             if args[0] in self.singleline:
                 statement = sqlpython.finishStatement(statement)
+            self.history.append(statement)
             return statement
         except Exception:
             return line
@@ -611,44 +651,50 @@
 
     bufferPosPattern = re.compile('\d+')
     rangeIndicators = ('-',':')
-    def bufferPositions(self, arg):
-        if not self.sqlBuffer:
-            return []
-        arg = arg.strip(self.terminator)
-        arg = arg.strip()
+    
+    def last_matching_command(self, arg):
         if not arg:
-            return [0]
-        arg = arg.strip().lower()
-        if arg in ('*', 'all', '-', ':'):
-            return range(len(self.sqlBuffer))
-
-        edges = [e for e in self.bufferPosPattern.findall(arg)]
-        edges = [int(e) for e in edges]
-        if len(edges) > 1:
-            edges = edges[:2]
+            return self.history[-2]
         else:
-            if arg[0] in self.rangeIndicators or arg[-1] in self.rangeIndicators:
-                edges.append(0)
-        edges.sort()
-        start = max(edges[0], 0)
-        end = min(edges[-1], len(self.sqlBuffer)-1)
-        return range(start, end+1)
+            history = self.history.get(arg)
+            if history:
+                return history[-1]
+        return None
+        
     def do_run(self, arg):
-        'run [N]: runs the SQL that was run N commands ago'	
-        for pos in self.bufferPositions(arg):
-            self.onecmd(self.sqlBuffer[-1-pos])
+        """run [arg]: re-runs an earlier command
+        
+        no arg -> run most recent command
+        arg is integer -> run one history item, by index
+        arg is string -> run most recent command by string search
+        arg is /enclosed in forward-slashes/ -> run most recent by regex
+        """        
+        'run [N]: runs the SQL that was run N commands ago'
+        runme = self.last_matching_command(arg)
+        print runme
+        self.onecmd(runme)
+    do_r = do_run
     def do_history(self, arg):
-        for (i, itm) in enumerate(self.history):
-            print '-------------------------[%d]' % (i+1)
-            print itm
+        """history [arg]: lists past commands issued
+        
+        no arg -> list all
+        arg is integer -> list one history item, by index
+        arg is string -> string search
+        arg is /enclosed in forward-slashes/ -> regular expression search
+        """
+        if arg:
+            history = self.history.get(arg)
+        else:
+            history = self.history
+        for hi in history:
+            hi.pr()
     def do_list(self, arg):
-        'list [N]: lists the SQL that was run N commands ago'
-        for pos in self.bufferPositions(arg):
-            print '*** %i statements ago ***' % pos
-            print self.sqlBuffer[-1-pos]
+        """list: lists single most recent command issued"""
+        self.last_matching_command(None).pr()
+    do_hi = do_history
+    do_l = do_list
     def load(self, fname):
         """Pulls command(s) into sql buffer.  Returns number of commands loaded."""
-        initialLength = len(self.sqlBuffer)
         try:
             f = open(fname, 'r')
         except IOError, e:
@@ -659,15 +705,13 @@
                 return 0
         txt = f.read()
         f.close()
-        self.sqlBuffer.extend(commandSeparator.separate(txt))                           
-        return len(self.sqlBuffer) - initialLength
+        result = commandSeparator.separate(txt)
+        self.history.extend(result) 
+        return len(result)
     def do_ed(self, arg):
         'ed [N]: brings up SQL from N commands ago in text editor, and puts result in SQL buffer.'
-        fname = 'mysqlpy_temp.sql'
-        try:
-            buffer = self.sqlBuffer[-1 - (int(arg or 0))]
-        except IndexError:
-            buffer = ''
+        fname = 'sqlpython_temp.sql'
+        buffer = self.last_matching_command(arg)
         f = open(fname, 'w')
         f.write(buffer)
         f.close()
@@ -681,11 +725,10 @@
             self.do_list('1-%d' % (commandsLoaded-1))
     def do_getrun(self, fname):
         'Brings SQL commands from a file to the in-memory SQL buffer, and executes them.'
-        commandNums = range(self.load(fname))
-        commandNums.reverse()
-        for commandNum in commandNums:
-            self.do_run(str(commandNum))
-            self.sqlBuffer.pop()
+        newCommands = self.load(fname) * -1
+        if newCommands:
+            for command in self.history[newCommands:]:
+                self.onecmd(command)
     def do_psql(self, arg):
         '''Shortcut commands emulating psql's backslash commands.
         
@@ -716,9 +759,12 @@
             print 'psql command \%s not yet supported.' % abbrev        
     def do_save(self, fname):
         'save FILENAME: Saves most recent SQL command to disk.'
-        f = open(fname, 'w')
-        f.write(self.sqlBuffer[-1])
-        f.close()
+        try:
+            f = open(fname, 'w')
+            f.write(self.sqlBuffer[-1])
+            f.close()
+        except Exception, e:
+            print 'Error saving %s: %s' % (fname, str(e))
         
     def do_print(self, arg):
         'print VARNAME: Show current value of bind variable VARNAME.'
--- a/sqlpython.py	Mon Dec 03 14:29:16 2007 -0500
+++ b/sqlpython.py	Mon Dec 03 16:43:43 2007 -0500
@@ -134,7 +134,10 @@
 
 def findTerminator(statement):
     m = stmtEndFinder.search(statement)
-    return m.groups()
+    if m:
+        return m.groups()
+    else:
+        return statement, None, None
     
 def pmatrix(rows,desc,maxlen=30):
     '''prints a matrix, used by sqlpython to print queries' result sets'''