changeset 322:d791333902f7

handle NULLs and apostrophes in \i inserts
author Catherine Devlin <catherine.devlin@gmail.com>
date Wed, 01 Apr 2009 16:17:46 -0400
parents 0c83ddee3a5c
children b97f1b8cdecd
files sqlpython/output_templates.py sqlpython/sqlpyPlus.py
diffstat 2 files changed, 12 insertions(+), 8 deletions(-) [+]
line wrap: on
line diff
--- a/sqlpython/output_templates.py	Wed Apr 01 14:48:05 2009 -0400
+++ b/sqlpython/output_templates.py	Wed Apr 01 16:17:46 2009 -0400
@@ -61,7 +61,7 @@
 {% end %}{% end %}"""),
 
 '\\i': genshi.template.NewTextTemplate("""{% for (rowNum, row) in enumerate(rows) %}
-INSERT INTO $tblname (${', '.join(colnames)}) VALUES (${', '.join(f % r for (r,f) in zip(row, formatters))});{% end %}"""),
+INSERT INTO $tblname (${', '.join(colnames)}) VALUES (${', '.join(formattedForSql(r) for r in row)});{% end %}"""),
 
 '\\c': genshi.template.NewTextTemplate("""
 ${','.join(colnames)}{% for row in rows %}
--- a/sqlpython/sqlpyPlus.py	Wed Apr 01 14:48:05 2009 -0400
+++ b/sqlpython/sqlpyPlus.py	Wed Apr 01 16:17:46 2009 -0400
@@ -435,19 +435,23 @@
             print '%s: %s' % (scchar, scto)
 
     tableNameFinder = re.compile(r'from\s+([\w$#_"]+)', re.IGNORECASE | re.MULTILINE | re.DOTALL)          
-    inputStatementFormatters = {
-        cx_Oracle.STRING: "'%s'",
-        cx_Oracle.DATETIME: "TO_DATE('%s', 'YYYY-MM-DD HH24:MI:SS')"}
-    inputStatementFormatters[cx_Oracle.CLOB] = inputStatementFormatters[cx_Oracle.STRING]
-    inputStatementFormatters[cx_Oracle.TIMESTAMP] = inputStatementFormatters[cx_Oracle.DATETIME]                
+    def formattedForSql(self, datum):
+        if datum is None:
+            return 'NULL'
+        elif isinstance(datum, basestring):
+            return "'%s'" % datum
+        try:
+            return datum.strftime("TO_DATE('%Y-%m-%d %H:%M:%S', 'YYYY-MM-DD HH24:MI:SS')")
+        except AttributeError:
+            return str(datum)
+              
     def output(self, outformat, rowlimit):
         self.tblname = self.tableNameFinder.search(self.querytext).group(1)
         self.colnames = [d[0] for d in self.curs.description]
         if outformat in output_templates:
             self.colnamelen = max(len(colname) for colname in self.colnames)
             self.coltypes = [d[1] for d in self.curs.description]
-            self.formatters = [self.inputStatementFormatters.get(typ, '%s') for typ in self.coltypes]    
-            result = output_templates[outformat].generate(**self.__dict__)        
+            result = output_templates[outformat].generate(formattedForSql=self.formattedForSql, **self.__dict__)        
         elif outformat == '\\t': # transposed
             rows = [self.colnames]
             rows.extend(list(self.rows))