comparison sqlpyPlus.py @ 0:9c87fa772ec1

before big refactor
author catherine@serenity.wpafb.af.mil
date Fri, 30 Nov 2007 13:04:51 -0500
parents
children 8fa146b9a2d7
comparison
equal deleted inserted replaced
-1:000000000000 0:9c87fa772ec1
1 """sqlpyPlus - extra features (inspired by Oracle SQL*Plus) for Luca Canali's sqlpython.py
2
3 Features include:
4 - SQL*Plus-style bind variables
5 - Query result stored in special bind variable ":_" if one row, one item
6 - SQL buffer with list, run, ed, get, etc.; unlike SQL*Plus, buffer stores session's full history
7 - @script.sql loads and runs (like SQL*Plus)
8 - ! runs operating-system command
9 - show and set to control sqlpython parameters
10 - SQL*Plus-style describe, spool
11 - write sends query result directly to file
12 - comments shows table and column comments
13 - compare ... to ... graphically compares results of two queries
14 - commands are case-insensitive
15
16 Use 'help' within sqlpython for details.
17
18 Compatible with sqlpython v1.3
19
20 Set bind variables the hard (SQL*Plus) way
21 exec :b = 3
22 or with a python-like shorthand
23 :b = 3
24
25 - catherinedevlin.blogspot.com May 31, 2006
26 """
27 # note in cmd.cmd about supporting emacs commands?
28
# Canned data-dictionary queries, keyed by purpose.
#
# Entries that do not begin with SELECT are fragments: they are handed to
# sqlpyPlus.do_select, which prepends 'select ' itself.
queries = {
    # Resolve a bare name: own objects first, then private synonyms,
    # then PUBLIC synonyms (lowest priority number wins).
    "resolve": """
SELECT object_type, object_name, owner FROM (
SELECT object_type, object_name, user owner, 1 priority
FROM user_objects
WHERE object_name = :objName
UNION ALL
SELECT ao.object_type, ao.object_name, ao.owner, 2 priority
FROM all_objects ao
JOIN user_synonyms us ON (us.table_owner = ao.owner AND us.table_name = ao.object_name)
WHERE us.synonym_name = :objName
AND ao.object_type != 'SYNONYM'
UNION ALL
SELECT ao.object_type, ao.object_name, ao.owner, 3 priority
FROM all_objects ao
JOIN all_synonyms asyn ON (asyn.table_owner = ao.owner AND asyn.table_name = ao.object_name)
WHERE asyn.synonym_name = :objName
AND ao.object_type != 'SYNONYM'
AND asyn.owner = 'PUBLIC'
) ORDER BY priority ASC""",
    # Column name, nullability and formatted datatype of a table or view.
    "descTable": """
atc.column_name,
CASE atc.nullable WHEN 'Y' THEN 'NULL' ELSE 'NOT NULL' END "Null?",
atc.data_type ||
CASE atc.data_type WHEN 'DATE' THEN ''
ELSE '(' ||
CASE atc.data_type WHEN 'NUMBER' THEN TO_CHAR(atc.data_precision) ||
CASE atc.data_scale WHEN 0 THEN ''
ELSE ',' || TO_CHAR(atc.data_scale) END
ELSE TO_CHAR(atc.data_length) END
END ||
CASE atc.data_type WHEN 'DATE' THEN '' ELSE ')' END
data_type
FROM all_tab_columns atc
WHERE atc.table_name = :object_name
AND atc.owner = :owner
ORDER BY atc.column_id;""",
    # Names of the procedures/functions inside a package.
    "PackageObjects": """
SELECT DISTINCT object_name
FROM all_arguments
WHERE package_name = :package_name
AND owner = :owner""",
    # Arguments of one packaged procedure/function.
    "PackageObjArgs": """
object_name,
argument_name,
data_type,
in_out,
default_value
FROM all_arguments
WHERE package_name = :package_name
AND object_name = :object_name
AND owner = :owner
AND argument_name IS NOT NULL
ORDER BY sequence""",
    # Arguments of a standalone procedure/function.
    "descProcedure": """
argument_name,
data_type,
in_out,
default_value
FROM all_arguments
WHERE object_name = :object_name
AND owner = :owner
AND package_name IS NULL
AND argument_name IS NOT NULL
ORDER BY sequence;""",
    # Comment attached to a table.
    "tabComments": """
SELECT comments
FROM all_tab_comments
WHERE owner = :owner
AND table_name = :table_name""",
    # Comments attached to each column of a table.
    "colComments": """
atc.column_name,
acc.comments
FROM all_tab_columns atc
JOIN all_col_comments acc ON (atc.owner = acc.owner and atc.table_name = acc.table_name and atc.column_name = acc.column_name)
WHERE atc.table_name = :object_name
AND atc.owner = :owner
ORDER BY atc.column_id;""",
}
108
109 import sys, os, re, sqlpython, cx_Oracle, pyparsing
110
# Back-port of the enumerate() builtin for interpreters older than 2.3.
# Compare the version tuple rather than parsing sys.version: slicing the
# first three characters misreads versions such as "3.10" as 3.1.
if sys.version_info < (2, 3):
    def enumerate(lst):
        """Return a list of (index, item) pairs for the sequence *lst*."""
        return zip(range(len(lst)), lst)
114
class SoftwareSearcher(object):
    """Finds and launches an external helper program for a given purpose.

    softwareList is an ordered list of (program, invokeString) pairs, most
    preferred first.  invokeString is a %-format template whose first slot
    receives the program and whose remaining slots receive the arguments
    passed to invoke().
    """

    def __init__(self, softwareList, purpose):
        self.softwareList = softwareList
        self.purpose = purpose
        self.software = None  # filled in lazily by the first invoke()

    def invoke(self, *args):
        """Run the program (locating it on first use) with *args* via os.system."""
        if not self.software:
            (self.software, self.invokeString) = self.find()
        argTuple = tuple([self.software] + list(args))
        os.system(self.invokeString % argTuple)

    def find(self):
        """Return (program, invokeString) for the first usable candidate.

        For the 'text editor' purpose the EDITOR environment variable wins.
        Otherwise each candidate is accepted if its path exists as given, or
        if its base name exists in a directory listed in PATH.

        Raises OSError if no candidate can be found.
        """
        if self.purpose == 'text editor':
            software = os.environ.get('EDITOR')
            if software:
                return (software, '%s %s')
        # A missing PATH must not crash the search; empty entries are skipped.
        searchDirs = [p for p in os.environ.get('PATH', '').split(os.pathsep) if p]
        for (n, (software, invokeString)) in enumerate(self.softwareList):
            if os.path.exists(software):
                # Candidates in the last ~30% of the list are fallbacks; tell
                # the user a better tool could be configured.
                if n > (len(self.softwareList) * 0.7):
                    print(("\n\n"
                           "Using %s. Note that there are better options available for %s,\n"
                           "but %s couldn't find a better one in your PATH.\n"
                           "Feel free to open up %s\n"
                           "and customize it to find your favorite %s program.\n\n")
                          % (software, self.purpose, __file__, __file__, self.purpose))
                return (software, invokeString)
            stem = os.path.split(software)[1]
            for p in searchDirs:
                if os.path.exists(os.path.join(p, stem)):
                    return (stem, invokeString)
        # Raise an exception instance; raising a (class, message) tuple
        # drops the message on Python 2 and is a TypeError on Python 3.
        raise OSError("Could not find any %s programs. You will need to install one,\n"
                      "or customize %s to make it aware of yours.\n"
                      "Looked for these programs:\n"
                      "%s" % (self.purpose, __file__,
                              "\n".join([s[0] for s in self.softwareList])))
151
# Candidate helper programs per purpose, most preferred first.
# Each entry is (program, invoke template); the template's first %s slot
# receives the program itself, the remaining slots its arguments.
softwareLists = {
    'diff/merge': [
        ('/usr/bin/meld', "%s %s %s"),
        ('/usr/bin/kdiff3', "%s %s %s"),
        (r'C:\Program Files\Araxis\Araxis Merge v6.5\Merge.exe', '"%s" %s %s'),
        (r'C:\Program Files\TortoiseSVN\bin\TortoiseMerge.exe',
         '"%s" /base:"%s" /mine:"%s"'),
        ('FileMerge', '%s %s %s'),
        ('kompare', '%s %s %s'),
        ('WinMerge', '%s %s %s'),
        ('xxdiff', '%s %s %s'),
        ('fldiff', '%s %s %s'),
        ('gtkdiff', '%s %s %s'),
        ('tkdiff', '%s %s %s'),
        ('gvimdiff', '%s %s %s'),
        ('diff', "%s %s %s"),
        (r'c:\windows\system32\comp.exe', "%s %s %s"),
    ],
    'text editor': [
        ('gedit', '%s %s'),
        ('textpad', '%s %s'),
        ('notepad.exe', '%s %s'),
        ('pico', '%s %s'),
        ('emacs', '%s %s'),
        ('vim', '%s %s'),
        ('vi', '%s %s'),
        ('ed', '%s %s'),
        ('edlin', '%s %s'),
    ],
}
180
# Module-wide searchers; each locates its program on first use.
diffMergeSearcher = SoftwareSearcher(softwareLists['diff/merge'], 'diff/merge')
editSearcher = SoftwareSearcher(softwareLists['text editor'], 'text editor')

# An EDITOR environment variable short-circuits the editor search entirely.
editor = os.environ.get('EDITOR')
if editor:
    editSearcher.find = lambda: (editor, "%s %s")
186
class CaselessDict(dict):
    """dict with case-insensitive keys.

    Keys are lower-cased on every access, so they must be strings.

    Posted to ASPN Python Cookbook by Jeff Donner - http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/66315"""

    def __init__(self, other=None):
        # Doesn't do keyword args
        dict.__init__(self)
        if other:
            self.update(other)

    def __getitem__(self, key):
        return dict.__getitem__(self, key.lower())

    def __setitem__(self, key, value):
        dict.__setitem__(self, key.lower(), value)

    def __delitem__(self, key):
        # Without this override, `del d['KEY']` looks up the raw key and fails.
        dict.__delitem__(self, key.lower())

    def __contains__(self, key):
        return dict.__contains__(self, key.lower())

    def has_key(self, key):
        # Delegates to __contains__ so it also works where dict.has_key is gone.
        return key in self

    def get(self, key, def_val=None):
        return dict.get(self, key.lower(), def_val)

    def setdefault(self, key, def_val=None):
        return dict.setdefault(self, key.lower(), def_val)

    def update(self, other):
        """Merge *other* (a mapping or an iterable of (key, value) pairs)."""
        if hasattr(other, 'items'):
            pairs = other.items()
        else:
            pairs = other
        for k, v in pairs:
            dict.__setitem__(self, k.lower(), v)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        """Return a new caseless dict whose keys all map to *value*."""
        d = cls()
        for k in iterable:
            dict.__setitem__(d, k.lower(), value)
        return d

    def pop(self, key, def_val=None):
        # Unlike dict.pop, a missing key returns def_val instead of raising.
        return dict.pop(self, key.lower(), def_val)
222
class NotSettableError(Exception):
    """Signals that a requested parameter exists but cannot be set."""
    pass
225
class Parser(object):
    """Wraps a pyparsing scanner that skips quoted strings and comments.

    separate() collects the pieces the scanner matches in a text.
    """
    # A "--" comment running to the end of the line.
    comment_def = "--" + pyparsing.ZeroOrMore(pyparsing.CharsNotIn("\n"))

    def __init__(self, scanner, retainSeparator=True):
        self.scanner = scanner
        # Matches inside these constructs must not count.
        for ignorable in (pyparsing.sglQuotedString,
                          pyparsing.dblQuotedString,
                          self.comment_def,
                          pyparsing.cStyleComment):
            self.scanner.ignore(ignorable)
        self.retainSeparator = retainSeparator

    def separate(self, txt):
        """Return the list of pieces the scanner finds in *txt*.

        For grouped results the pieces are joined (separator kept) when
        retainSeparator is true, otherwise only the first element is kept.
        Empty matches are dropped.
        """
        pieces = []
        for (tokens, _start, _end) in self.scanner.scanString(txt):
            if not tokens:
                continue
            head = tokens[0]
            if type(head) == pyparsing.ParseResults:
                if self.retainSeparator:
                    pieces.append("".join(head))
                else:
                    pieces.append(head[0])
            elif head:
                pieces.append(head)
        return pieces
248
# Splits "command | shell-command" at the pipe character (pipe discarded).
pipeSeparator = Parser(
    pyparsing.SkipTo((pyparsing.Literal('|') ^ pyparsing.StringEnd()), include=True),
    retainSeparator=False)

# Recognizes bind-variable references of the form ":name".
bindScanner = Parser(
    pyparsing.Literal(':') + pyparsing.Word(pyparsing.alphanums + "_$#"))

# Splits a script into statements at ";" (terminator kept).
commandSeparator = Parser(
    pyparsing.SkipTo((pyparsing.Literal(';') ^ pyparsing.StringEnd()), include=True))
252
def findBinds(target, existingBinds, givenBindVars=None):
    """Collect values for the bind variables referenced in *target*.

    target        -- SQL text scanned for ":name" references
    existingBinds -- mapping of currently defined bind variables
    givenBindVars -- optional mapping of values supplied by the caller;
                     an existing bind with the same name overrides it

    Returns a new dict of name -> value.  The caller's mapping is not
    modified; the previous shared ``{}`` default was mutated in place and
    therefore leaked binds from one call into the next.  A message is
    printed for every referenced variable that has no value at all.
    """
    result = {}
    if givenBindVars:
        result.update(givenBindVars)
    for finding, startat, endat in bindScanner.scanner.scanString(target):
        varname = finding[1]  # finding[0] is the ':' prefix
        try:
            result[varname] = existingBinds[varname]
        except KeyError:
            if varname not in result:
                print('Bind variable %s not defined.' % (varname))
    return result
263
class sqlpyPlus(sqlpython.sqlpython):
    """sqlpython shell extended with SQL*Plus-style conveniences."""

    def __init__(self):
        sqlpython.sqlpython.__init__(self)
        # Bind variables and the history of statements run this session.
        self.binds = CaselessDict()
        self.sqlBuffer = []
        # Spooling state: the stream to restore and the open spool file.
        self.stdoutBeforeSpool = sys.stdout
        self.spoolFile = None
        # Parameters reported by "show"; settables must be lowercase.
        self.settable = ['maxtselctrows', 'maxfetch', 'autobind', 'failover', 'timeout']
        self.autobind = False
        self.failover = False
274
def default(self, arg, do_everywhere=False):
    """Run an unrecognized command through the parent class, then record
    the resulting query in the SQL buffer."""
    parent = sqlpython.sqlpython
    parent.default(self, arg, do_everywhere)
    self.sqlBuffer.append(self.query)
278
# First-character shortcuts understood by parseline (overrides cmd's parseline).
shortcuts = {'?': 'help', '@': 'getrun', '!': 'shell', ':': 'setbind', '\\': 'psql'}
def parseline(self, line):
    """Parse the line into a command name and a string containing
    the arguments.  Returns a tuple containing (command, args, line).
    'command' and 'args' may be None if the line couldn't be parsed.
    Overrides cmd.cmd.parseline to accept variety of shortcuts.."""
    line = line.strip()
    if not line:
        return None, None, line
    shortcut = self.shortcuts.get(line[0])
    if shortcut:
        cmd, arg = shortcut, line[1:].strip()
    else:
        # The command is the leading run of identifier characters.
        end = 0
        for ch in line:
            if ch not in self.identchars:
                break
            end += 1
        cmd, arg = line[:end], line[end:].strip()
    # These commands need a database cursor; refuse them when none exists.
    needsConnection = cmd.lower() in ('select', 'sleect', 'insert', 'update',
                                      'delete', 'describe', 'desc', 'comments')
    if needsConnection and not hasattr(self, 'curs'):
        print('Not connected.')
        return '', '', ''
    return cmd, arg, line
302
def precmd(self, line):
    """Hook method executed just before the command line is
    interpreted, but after the input prompt is generated and issued.

    Makes commands case-insensitive by lower-casing the first word (but
    unfortunately does not alter command completion).  If the line
    contains a pipe, the part before it is first run with stdout captured
    in a temporary file, which is then fed to the shell command after it.
    """
    savestdout = sys.stdout
    pipefilename = 'sqlpython.pipeline.tmp'
    pipedCommands = pipeSeparator.separate(line)
    if len(pipedCommands) > 1:
        f = open(pipefilename, 'w')
        try:
            sys.stdout = f
            self.precmd(pipedCommands[0])
            self.onecmd(pipedCommands[0])
            self.postcmd(False, pipedCommands[0])
        finally:
            # Always give the terminal back, even if the command raised.
            sys.stdout = savestdout
            f.close()
        os.system('%s < %s' % (pipedCommands[1], pipefilename))
        # NOTE(review): the full line (pipe included) is still returned
        # below and will be interpreted again - confirm this is intended.
    args = line.split(None, 1)
    if not args:
        # Blank line: nothing to lower-case.
        return line
    args[0] = args[0].lower()
    return ' '.join(args)
326
def do_shortcuts(self,arg):
    """Lists available first-character shortcuts
    (i.e. '!dir' is equivalent to 'shell dir')"""
    for pair in self.shortcuts.items():
        # pair is (shortcut character, command it expands to)
        print('%s: %s' % pair)
332
333
def colnames(self):
    """Return the column names of the cursor's current result set.

    Reads the instance's cursor; the previous code referenced a bare
    ``curs`` name that is not defined anywhere and raised NameError.
    """
    return [d[0] for d in self.curs.description]
336
def sql_format_itm(self, itm, needsquotes):
    """Render one value as a SQL literal.

    None becomes NULL.  When *needsquotes* is true the value is wrapped in
    single quotes and embedded single quotes are doubled, so a value such
    as O'Brien still yields a valid string literal.
    """
    if itm is None:
        return 'NULL'
    if needsquotes:
        return "'%s'" % str(itm).replace("'", "''")
    return str(itm)
def output_as_insert_statements(self):
    """Render the fetched rows as one INSERT statement per row."""
    # Every column that is not a cx_Oracle NUMBER gets quoted.
    quoteFlags = [col[1] != cx_Oracle.NUMBER for col in self.curs.description]
    statements = []
    for row in self.rows:
        values = ','.join([self.sql_format_itm(itm, quoted)
                           for (itm, quoted) in zip(row, quoteFlags)])
        statements.append('INSERT INTO %s (%s) VALUES (%s);' %
                          (self.tblname, ','.join(self.colnames), values))
    return '\n'.join(statements)
352
def output_row_as_xml(self, row):
    """Render one row as a sequence of <column>value</column> elements.

    Element names are the lower-cased column names; None becomes an empty
    value.  Values are XML-escaped so that &, < and > in the data cannot
    produce malformed markup.
    """
    from xml.sax.saxutils import escape
    result = [' <%s>\n %s\n </%s>' %
              (colname.lower(), escape(str('' if (itm is None) else itm)), colname.lower())
              for (itm, colname) in zip(row, self.colnames)]
    return '\n'.join(result)
def output_as_xml(self):
    """Render every fetched row as an element named after the table."""
    elements = []
    for row in self.rows:
        body = self.output_row_as_xml(row)
        elements.append('<%s>\n%s\n</%s>' % (self.tblname, body, self.tblname))
    return '\n'.join(elements)
363
def output_as_html_table(self):
    """Render the fetched rows as an HTML <table> whose id is the table name.

    The first <tr> holds the column headings; None is shown as an empty
    cell.  Headings and cell values are escaped for HTML.
    """
    from xml.sax.saxutils import escape
    header = ''.join(['<th>%s</th>' % escape(str(c)) for c in self.colnames])
    lines = [' <tr>\n %s\n </tr>' % header]
    for row in self.rows:
        cells = ''.join(['<td>%s</td>' % escape(str('' if (itm is None) else itm))
                         for itm in row])
        lines.append(' <tr>\n %s\n </tr>' % cells)
    # Return the assembled string as is: joining it again with '\n'
    # would insert a newline between every single character.
    return '<table id="%s">\n%s\n</table>' % (self.tblname, '\n'.join(lines))
376
def output_as_list(self, align):
    """Render each row vertically, one "column: value" line per column.

    When *align* is true the column names are padded to a common width.
    """
    width = 1 + max([len(name) for name in self.colnames])
    lines = []
    rownum = 0
    for row in self.rows:
        rownum += 1
        lines.append('\n**** Row: %d' % rownum)
        for (itm, colname) in zip(row, self.colnames):
            label = colname
            if align:
                label = colname.ljust(width)
            lines.append('%s: %s' % (label, itm))
    return '\n'.join(lines)
387
# Captures the first identifier that follows FROM in a statement.
tableNameFinder = re.compile(r'from\s+([\w$#_"]+)', re.IGNORECASE | re.MULTILINE | re.DOTALL)
def output(self, outformat, rowlimit):
    """Format the most recently fetched rows according to *outformat*.

    outformat -- terminator code selecting the layout: \\i INSERT
                 statements, \\x XML, \\g / \\G list (unaligned / aligned),
                 \\s \\c CSV with header, \\S \\C CSV without header,
                 \\h HTML table; anything else gives the tabular layout.
    rowlimit  -- accepted for interface compatibility; not used here.

    Sets self.tblname and self.colnames as a side effect and returns the
    formatted text.
    """
    found = self.tableNameFinder.search(self.curs.statement)
    # A statement without a FROM clause used to crash on .group(); fall
    # back to a placeholder name instead.
    self.tblname = found.group(1) if found else 'result'
    self.colnames = [d[0] for d in self.curs.description]
    if outformat == '\\i':
        result = self.output_as_insert_statements()
    elif outformat == '\\x':
        result = self.output_as_xml()
    elif outformat == '\\g':
        result = self.output_as_list(align=False)
    elif outformat == '\\G':
        result = self.output_as_list(align=True)
    elif outformat in ('\\s', '\\S', '\\c', '\\C'): #csv
        result = []
        if outformat in ('\\s', '\\c'):
            # Lower-case variants include a header line.
            result.append(','.join(['"%s"' % colname for colname in self.colnames]))
        for row in self.rows:
            result.append(','.join(['"%s"' % ('' if itm is None else itm) for itm in row]))
        result = '\n'.join(result)
    elif outformat == '\\h':
        result = self.output_as_html_table()
    else:
        result = sqlpython.pmatrix(self.rows, self.curs.description, self.maxfetch)
    return result
412
def do_select(self, arg, bindVarsIn=None):
    """Fetch rows from a table.

    Limit the number of rows retrieved by appending
    an integer after the terminator
    (example: SELECT * FROM mytable;10 )

    Output may be formatted by choosing an alternative terminator
    ("help terminators" for details)
    """
    if not bindVarsIn:
        bindVarsIn = {}
    statement = sqlpython.Statement('select ' + arg)
    self.query = statement.query
    if statement.outformat == '\\t':
        # Hand everything after the leading "select" to do_tselect.
        remainder = ' '.join(self.query.split()[1:])
        self.do_tselect(remainder + ';', statement.rowlimit)
    else:
        try:
            self.varsUsed = findBinds(self.query, self.binds, bindVarsIn)
            self.curs.execute(self.query, self.varsUsed)
            fetchSize = min(self.maxfetch, (statement.rowlimit or self.maxfetch))
            self.rows = self.curs.fetchmany(fetchSize)
            self.desc = self.curs.description
            self.rc = self.curs.rowcount
            if self.rc > 0:
                print('\n' + self.output(statement.outformat, statement.rowlimit))
            if self.rc == 0:
                print('\nNo rows Selected.\n')
            elif self.rc == 1:
                print('\n1 row selected.\n')
                if self.autobind:
                    # Expose each column of the single row as a bind variable.
                    names = [d[0] for d in self.desc]
                    self.binds.update(dict(zip(names, self.rows[0])))
            elif self.rc < self.maxfetch:
                print('\n%d rows selected.\n' % self.rc)
            else:
                print('\nSelected Max Num rows (%d)' % self.rc)
        except Exception as e:
            print(e)
            import traceback
            traceback.print_exc(file=sys.stdout)
    # NOTE(review): assumed to run for both branches - confirm nesting.
    self.sqlBuffer.append(self.query)
452
def showParam(self, param):
    """Print "name: value" for *param* if it is one of the settables;
    unknown names are ignored silently."""
    name = param.strip().lower()
    if name not in self.settable:
        return
    print('%s: %s' % (name, str(getattr(self, name))))
458
def do_show(self, arg):
    'Shows value of a (sqlpython, not ORACLE) parameter'
    wanted = arg.strip().lower()
    # No argument means "show every settable parameter".
    if wanted:
        names = [wanted]
    else:
        names = self.settable
    for name in names:
        self.showParam(name)
467
def cast(self, current, new):
    """Convert the string *new* to the type of *current*.

    For a boolean parameter, 'on' or any text starting with 'y' or 't'
    (case-insensitive) means True and everything else - including an
    empty string, which used to raise IndexError - means False.  For other
    types the type's constructor is applied; if that fails a message is
    printed and *current* is returned unchanged.
    """
    typ = type(current)
    if typ == bool:
        try:
            lowered = new.lower()
        except AttributeError:
            # Not a string: fall through to the generic bool(new) below.
            pass
        else:
            return (lowered == 'on') or (lowered[:1] in ('y', 't'))
    try:
        return typ(new)
    except Exception:
        # Best effort by design: report the problem and keep the old value.
        print("Problem setting parameter (now %s) to %s; incorrect type?" % (current, new))
        return current
483
def do_set(self, arg):
    'Sets a (sqlpython, not ORACLE) parameter'
    try:
        paramName, val = arg.split(None, 1)
    except Exception:
        # No value supplied: behave like "show".
        self.do_show(arg)
        return
    paramName = paramName.lower()
    # Only existing, non-callable attributes may be assigned.
    try:
        current = getattr(self, paramName)
    except AttributeError:
        settable = False
    else:
        settable = not callable(current)
    if not settable:
        self.fail('set %s' % arg)
        return
    val = self.cast(current, val.strip(';'))
    print('%s  - was:  %s' % (paramName, current))
    setattr(self, paramName.lower(), val)
    print('now:  %s' % (val,))
503
def do_describe(self, arg):
    "emulates SQL*Plus's DESCRIBE"
    object_type, owner, object_name = self.resolve(arg.strip(self.terminator).upper())
    if not object_type:
        # resolve() found nothing (and has said so); nothing to describe.
        return
    print("%s %s.%s" % (object_type, owner, object_name))
    if object_type in ('TABLE','VIEW'):
        self.do_select(queries['descTable'],{'object_name':object_name, 'owner':owner})
    elif object_type == 'PACKAGE':
        self.curs.execute(queries['PackageObjects'], {'package_name':object_name, 'owner':owner})
        # Fetch all member names up front: do_select re-executes on the
        # same cursor, which would cut a live iteration short.
        packageObjects = self.curs.fetchall()
        for (packageObj_name,) in packageObjects:
            print(packageObj_name)
            self.do_select(queries['PackageObjArgs'],{'package_name':object_name, 'owner':owner, 'object_name':packageObj_name})
    else:
        self.do_select(queries['descProcedure'],{'owner':owner, 'object_name':object_name})
do_desc = do_describe
518
def do_comments(self, arg):
    'Prints comments on a table and its columns.'
    target = arg.strip(self.terminator).upper()
    object_type, owner, object_name = self.resolve(target)
    if not object_type:
        return
    # Table-level comment first, then one line per column.
    self.curs.execute(queries['tabComments'],{'table_name':object_name, 'owner':owner})
    tableComment = self.curs.fetchone()[0]
    print("%s %s.%s: %s" % (object_type, owner, object_name, tableComment))
    self.do_select(queries['colComments'],{'owner':owner, 'object_name': object_name})
526
def resolve(self, identifier):
    """Checks (my objects).name, (my synonyms).name, (public synonyms).name
    to resolve a database object's name.

    identifier -- "NAME" or "OWNER.NAME"

    Returns (object_type, owner, object_name).  If the object cannot be
    found, or the identifier has more than two dotted parts, a message is
    printed and three empty strings are returned.
    """
    # Start from "unresolved" so every path returns bound values;
    # identifiers like A.B.C used to end in UnboundLocalError.
    object_type, owner, object_name = '', '', ''
    parts = identifier.split('.')
    found = False
    try:
        if len(parts) == 2:
            owner, object_name = parts
            self.curs.execute('SELECT object_type FROM all_objects WHERE owner = :owner AND object_name = :object_name',
                              {'owner': owner, 'object_name': object_name})
            object_type = self.curs.fetchone()[0]
            found = True
        elif len(parts) == 1:
            object_name = parts[0]
            self.curs.execute(queries['resolve'], {'objName':object_name})
            object_type, object_name, owner = self.curs.fetchone()
            found = True
    except TypeError:
        # fetchone() returned None: no such object.
        pass
    if not found:
        print('Could not resolve object %s.' % identifier)
        object_type, owner, object_name = '', '', ''
    return object_type, owner, object_name
545
def do_shell(self, arg):
    'execute a command as if at the OS prompt.'
    # The text is handed to the system shell unchanged.
    command = arg
    os.system(command)
549
def spoolstop(self):
    """Stop spooling, if active: restore the saved stdout, announce the
    finished file and close it."""
    if not self.spoolFile:
        return
    sys.stdout = self.stdoutBeforeSpool
    print('Finished spooling to  %s' % (self.spoolFile.name,))
    self.spoolFile.close()
    self.spoolFile = None
556
def do_spool(self, arg):
    """spool [filename] - begins redirecting output to FILENAME."""
    self.spoolstop()
    target = arg.strip() or 'output.lst'
    if target.lower() == 'off':
        # "spool off" only stops spooling, which has already happened.
        return
    if '.' not in target:
        target = '%s.lst' % target
    print('Sending output to %s (until SPOOL OFF received)' % (target))
    self.spoolFile = open(target, 'w')
    sys.stdout = self.spoolFile
569
def write(self, arg, fname):
    """Run the command *arg* with its standard output captured in *fname*.

    The previous stdout is restored and the file closed even when the
    command raises; otherwise the session would be left writing into a
    closed or abandoned file.
    """
    originalOut = sys.stdout
    f = open(fname, 'w')
    try:
        sys.stdout = f
        self.onecmd(arg)
    finally:
        sys.stdout = originalOut
        f.close()
577
def do_write(self, args):
    'write [filename.extension] query - writes result to a file'
    # Default target; a first word containing a dot is taken as the file name.
    fname, command = 'output.txt', args
    words = args.split(None, 1)
    if len(words) > 1 and '.' in words[0]:
        fname, command = words
    self.write(command, fname)
    fullPath = os.path.join(os.getcwd(), fname)
    print('Results written to %s' % fullPath)
587
def do_compare(self, args):
    """COMPARE query1 TO query2 - uses external tool to display differences.

    Sorting is recommended to avoid false hits."""
    # Split once on the keyword in any letter case.
    parts = [p.strip() for p in re.compile(r'\s+to\s+', re.IGNORECASE).split(args, 1)]
    if len(parts) != 2 or not parts[0] or not parts[1]:
        # Without two non-empty queries there is nothing to compare.
        print('Usage: COMPARE query1 TO query2')
        return
    fnames = []
    for (n, query) in enumerate(parts):
        fname = 'compare%s.txt' % n
        if not query.endswith(self.terminator):
            query = '%s%s' % (query, self.terminator)
        self.write(query, fname)
        fnames.append(fname)
    diffMergeSearcher.invoke(fnames[0], fnames[1])
601
602 bufferPosPattern = re.compile('\d+')
603 rangeIndicators = ('-',':')
604 def bufferPositions(self, arg):
605 if not self.sqlBuffer:
606 return []
607 arg = arg.strip(self.terminator)
608 arg = arg.strip()
609 if not arg:
610 return [0]
611 arg = arg.strip().lower()
612 if arg in ('*', 'all', '-', ':'):
613 return range(len(self.sqlBuffer))
614
615 edges = [e for e in self.bufferPosPattern.findall(arg)]
616 edges = [int(e) for e in edges]
617 if len(edges) > 1:
618 edges = edges[:2]
619 else:
620 if arg[0] in self.rangeIndicators or arg[-1] in self.rangeIndicators:
621 edges.append(0)
622 edges.sort()
623 start = max(edges[0], 0)
624 end = min(edges[-1], len(self.sqlBuffer)-1)
625 return range(start, end+1)
def do_run(self, arg):
    'run [N]: runs the SQL that was run N commands ago'
    positions = self.bufferPositions(arg)
    for stepsBack in positions:
        # The buffer is indexed from its end: 0 is the newest entry.
        statement = self.sqlBuffer[-1 - stepsBack]
        self.onecmd(statement)
def do_list(self, arg):
    'list [N]: lists the SQL that was run N commands ago'
    positions = self.bufferPositions(arg)
    for stepsBack in positions:
        print('*** %i statements ago ***' % stepsBack)
        # The buffer is indexed from its end: 0 is the newest entry.
        print(self.sqlBuffer[-1 - stepsBack])
def load(self, fname):
    """Pulls command(s) into sql buffer.  Returns number of commands loaded.

    Tries *fname* first and then *fname* + '.sql'.  If neither can be
    opened, the first error is printed and 0 is returned.
    """
    initialLength = len(self.sqlBuffer)
    try:
        f = open(fname, 'r')
    except IOError as e:
        try:
            f = open('%s.sql' % fname, 'r')
        except IOError:
            # Report the error for the name the user actually typed.
            print('Problem opening file %s: \n%s' % (fname, e))
            return 0
    try:
        txt = f.read()
    finally:
        f.close()
    self.sqlBuffer.extend(commandSeparator.separate(txt))
    return len(self.sqlBuffer) - initialLength
def do_ed(self, arg):
    'ed [N]: brings up SQL from N commands ago in text editor, and puts result in SQL buffer.'
    fname = 'mysqlpy_temp.sql'
    stepsBack = int(arg or 0)
    try:
        text = self.sqlBuffer[-1 - stepsBack]
    except IndexError:
        # Nothing that far back: start the editor on an empty file.
        text = ''
    scratch = open(fname, 'w')
    scratch.write(text)
    scratch.close()
    editSearcher.invoke(fname)
    # Whatever was saved in the editor becomes new buffer entries.
    self.load(fname)
do_edit = do_ed
def do_get(self, fname):
    'Brings SQL commands from a file to the in-memory SQL buffer.'
    commandsLoaded = self.load(fname)
    if commandsLoaded:
        # The new commands occupy positions 0 .. commandsLoaded-1 (0 is
        # the newest); starting the range at 1 left the last one out.
        self.do_list('0-%d' % (commandsLoaded-1))
def do_getrun(self, fname):
    'Brings SQL commands from a file to the in-memory SQL buffer, and executes them.'
    loaded = self.load(fname)
    # Oldest loaded command first: count positions down to 0.
    for commandNum in range(loaded - 1, -1, -1):
        self.do_run(str(commandNum))
        # Drop the newest buffer entry after each run.
        self.sqlBuffer.pop()
def do_psql(self, arg):
    r'''Shortcut commands emulating psql's backslash commands.

    \c connect
    \d desc
    \e edit
    \g run
    \h help
    \i getrun
    \o spool
    \p list
    \w save
    \? help psql'''
    # The table above is live data: every line after the first two is
    # parsed as "\<abbrev> <command>", so keep that layout intact.
    commands = {}
    for c in self.do_psql.__doc__.splitlines()[2:]:
        (abbrev, command) = c.split(None, 1)
        commands[abbrev[1:]] = command
    words = arg.split(None,1)
    if not words:
        # A lone backslash used to raise IndexError; show the table instead.
        print(self.do_psql.__doc__)
        return
    abbrev = words[0]
    if len(words) > 1:
        args = words[1]
    else:
        args = ''
    # Look the abbreviation up before dispatching, so that a KeyError
    # raised inside the command is not mistaken for an unknown one.
    command = commands.get(abbrev)
    if command is None:
        print('psql command \\%s not yet supported.' % abbrev)
        return
    self.onecmd('%s %s' % (command, args))
    # NOTE(review): kept from the original; what 'q' does is defined
    # outside this file - confirm it is meant to run after every command.
    self.onecmd('q')
def do_save(self, fname):
    'save FILENAME: Saves most recent SQL command to disk.'
    if not self.sqlBuffer:
        # Check before opening, so no empty file is left behind.
        print('SQL buffer is empty; nothing to save.')
        return
    f = open(fname, 'w')
    try:
        f.write(self.sqlBuffer[-1])
    finally:
        f.close()
708
def do_print(self, arg):
    'print VARNAME: Show current value of bind variable VARNAME.'
    if not arg:
        # No name given: list every bind variable.
        self.do_setbind('')
        return
    if arg[0] == ':':
        arg = arg[1:]
    try:
        print(self.binds[arg])
    except KeyError:
        print('No bind variable  %s' % (arg,))
def do_setbind(self, arg):
    """setbind [VAR [= VALUE]]: lists, shows or sets bind variables.

    With no argument every bind variable is listed; with one word that
    variable is shown; "VAR = VALUE" or "VAR := VALUE" assigns it.  A
    value wrapped in single quotes has the quotes removed.
    """
    args = arg.split(None, 2)
    if len(args) == 0:
        for (var, val) in self.binds.items():
            print(':%s = %s' % (var, val))
    elif len(args) == 1:
        try:
            print(':%s = %s' % (args[0], self.binds[args[0]]))
        except KeyError:
            # The message was previously taken from an undefined name,
            # turning an unknown variable into a NameError.
            print('No bind variable %s' % args[0])
    elif len(args) > 2 and args[1] in ('=',':='):
        var, val = args[0], args[2]
        if val[0] == val[-1] == "'" and len(val) > 1:
            val = val[1:-1]
        self.binds[var] = val
    else:
        print('Could not parse  %s' % (args,))
def do_exec(self, arg):
    """exec :VAR = VALUE sets a bind variable; any other argument is passed
    on to the default handler as an "exec ..." command."""
    # Slice rather than index so an empty argument does not raise IndexError.
    if arg[:1] == ':':
        self.do_setbind(arg[1:])
    else:
        self.default('exec %s' % arg)
742
def _test():
    """Run this module's doctests."""
    from doctest import testmod
    testmod()
746
if __name__ == "__main__":
    # Silent return implies that all unit tests succeeded.  Use -v to see details.
    _test()