view sqlpython/dbapiext.py @ 524:96803d93b9ae

show fixed for pgsql
author Catherine Devlin <catherine.devlin@gmail.com>
date Thu, 18 Nov 2010 18:47:37 -0500
parents d7eaa2b0c25f
children
line wrap: on
line source

"""
An extension to DBAPI-2.0 for more easily building SQL statements.

This extension allows you to call a DBAPI Cursor's execute method with a string
that contains format specifiers for escaped and/or unescaped arguments.  Escaped
arguments are specified using `` %X `` or `` %S `` (capital X or capital S).
You can also mix positional and keyword arguments in the call, and this takes
advantage of the Python call syntax niceties.  Also, lists passed in as
parameters to be formatted are automatically detected and joined by commas (this
works for both unescaped and escaped parameters-- lists to be escaped have their
elements escaped individually).  In addition, if you pass in a dictionary
corresponding to an escaped formatting specifier, the dictionary is rendered as
a list of comma-separated <key> = <value> pairs, such as are suitable for the
SET clause of an UPDATE statement.

For performance, the result of analysing and preparing each query is kept in a
cache and reused on subsequent calls, similarly to the re or struct modules.

(This is intended to become a reference implementation for a proposal for an
extension to the DBAPI-2.0.)

.. note:: the transformation supports the DBAPI parameter styles ``pyformat``
          (e.g. psycopg2), ``named``, ``qmark``, ``format`` and ``numeric``,
          plus a non-standard ``atnamed`` style.  The default is ``pyformat``;
          change it with set_paramstyle(), or pass ``paramstyle`` to qcompile().

For more details and motivation, see the accompanying explanation document at
http://furius.ca/pubcode/pub/conf/common/lib/python/dbapiext.html

5-minute usage instructions:

  Run execute_f() with a cursor object and appropriate arguments::

    execute_f(cursor, ' SELECT %s FROM %(t)s WHERE id = %S ', cols, id, t=table)

  Ideally, we should be able to monkey-patch this method onto the cursor class
  of the DBAPI library (this may not be possible if it is an extension module).

  By default, the result of analyzing each query is cached automatically and
  reused on further invocations, to minimize the amount of analysis to be
  performed at runtime.  If you want to do this explicitly, first compile your
  query, and execute it later with the resulting object, e.g.::

    analq = qcompile(' SELECT %s FROM %(t)s WHERE id = %S ')
    ...
    analq.execute(cursor, cols, id, t=table)

**Note to developers: this module contains tests, if you make any changes,
please make sure to run and fix the tests.**


Also, a formatting specifier is provided for WHERE clauses: ``%A``.  A
dictionary argument is rendered as <key> = <value> pairs joined with ``AND``
(a sequence argument has its elements escaped individually and joined with
``AND``).  An ``OR`` version, ``%O``, is provided as well.


Future Work
===========

- We could provide a reduce() method on the QueryAnalyzer, that will apply the
  given parameters and save the calculated arguments for later use; This would
  allow us to apply queries using multiple calls, to fill in only certain
  parameters at a time.  This method would return a new QueryAnalyzer, albeit
  one that would contain some pre-cooked apply_kwds and delay_kwds to be
  accumulated to in the apply call.

- Provide a simple test function that would allow people to test their queries
  without having to create a TestCursor.


"""

# stdlib imports
import re
from itertools import starmap, imap
from StringIO import StringIO
from datetime import date, datetime
from itertools import izip, count
from pprint import pprint


# Public API of this module.  Note that execute_obj is set to None when the
# namedtuple support it relies on cannot be imported.
__all__ = ('execute_f', 'qcompile', 'set_paramstyle', 'execute_obj')


class QueryAnalyzer(object):
    """
    Analyze and contain a query string in a way that we can quickly put it back
    together when given the actual arguments.  This object contains knowledge of
    which arguments are positional and keyword, and is able to conditionally
    apply escaping when necessary, and expand lists as well.

    This is meant to be kept around or cached for efficiency.
    """

    # Note: the last few formatting characters are extra, from us: 'S' and 'X'
    # mark escaped arguments, 'A' and 'O' escaped arguments joined by AND / OR.
    re_fmt = '[#0 +-]?([0-9]+|\\*)?(\\.[0-9]*)?[hlL]?[diouxXeEfFgGcrsSAO]'

    # Group 2 is the optional keyword name, group 3 the format specification.
    regexp = re.compile('%%(\\(([a-zA-Z0-9_]+)\\))?(%s)' % re_fmt)

    # Format specifications that denote escaped arguments, mapped to the
    # separator used when a sequence or dict argument is expanded.
    _escaped_seps = {'S': ', ', 'X': ', ', 'A': ' AND ', 'O': ' OR '}

    def __init__(self, query, paramstyle=None):
        self.orig_query = query

        # List of positional arguments to be consumed later.  The list
        # consists of the generated keynames ('__p1', '__p2', ...).
        self.positional = []

        # A sequence of literal strings and (keyname, escaped, sep, fmt)
        # tuples; filled in by analyze().
        self.components = None

        # The parameter style supported by the underlying DBAPI.
        if paramstyle is None:
            paramstyle = _def_paramstyle
        self.paramstyle = paramstyle
        self.init_style(paramstyle)

        self.analyze() # Initialize.

    def init_style(self, paramstyle):
        """
        Pre-calculate style-specific constants: style_fmt, the template of an
        escaped placeholder (filled with 'name' and/or 'no'), and
        style_argstype, the container type (dict or list) of the escaped
        arguments handed to the DBAPI.  Raises ValueError for an unknown style.
        """
        # '%%%%' survives the two %-formatting passes done in apply()
        # (style_fmt % ..., then query % apply_kwds) as a single '%'.
        if paramstyle == 'pyformat':
            self.style_fmt = '%%%%(%(name)s)s'
            self.style_argstype = dict
        elif paramstyle == 'named':
            self.style_fmt = ':%(name)s'
            self.style_argstype = dict
        elif paramstyle == 'qmark':
            self.style_fmt = '?'
            self.style_argstype = list
        elif paramstyle == 'format':
            self.style_fmt = '%%%%s'
            self.style_argstype = list
        elif paramstyle == 'numeric':
            self.style_fmt = ':%(no)d'
            self.style_argstype = list
        # Non-standard. For our modified Sybase (from 0.37).
        elif paramstyle == 'atnamed':
            self.style_fmt = '@%(name)s'
            self.style_argstype = dict
        else:
            raise ValueError(
                "Parameter style '%s' is not supported." % paramstyle)

    def analyze(self):
        """
        Split the original query into literal strings and placeholder tuples
        (keyname, escaped, sep, fmt), stored in self.components.  Unnamed
        placeholders get generated keynames, recorded in order in
        self.positional.  Safe to call more than once.
        """
        poscount = count(1)
        self.positional = []

        comps = self.components = []
        for x in gensplit(self.regexp, self.orig_query):
            if isinstance(x, (str, unicode)):
                comps.append(x)
                continue

            keyname, fmt = x.group(2, 3)
            if keyname is None:
                keyname = '__p%d' % poscount.next()
                self.positional.append(keyname)

            sep = self._escaped_seps.get(fmt)
            escaped = sep is not None
            if escaped:
                # Escaped values are rendered as strings by the DBAPI.
                fmt = 's'
            else:
                sep = ', '
            comps.append( (keyname, escaped, sep, fmt) )

    def __str__(self):
        """
        Return the string that would be used before application of the
        positional and keyword arguments.
        """
        style_fmt = self.style_fmt
        oss = StringIO()
        no = count(1)
        for x in self.components:
            if isinstance(x, (str, unicode)):
                oss.write(x)
            else:
                keyname, escaped, sep, fmt = x
                if escaped:
                    oss.write(style_fmt % {'name': keyname,
                                           'no': no.next()})
                else:
                    oss.write('%%(%s)%s' % (keyname, fmt))
        return oss.getvalue()

    def apply(self, *args, **kwds):
        """
        Render the query with the given positional and keyword arguments.

        Returns a pair (query, params): the query with unescaped arguments
        substituted in and escaped arguments replaced by placeholders of this
        analyzer's paramstyle, and the escaped values for the DBAPI (a dict
        for named styles, a list for positional styles).

        Raises TypeError if the number of positional arguments does not match
        the number of unnamed placeholders, KeyError if a named placeholder
        has no value, and ValueError if a dict is given for an unescaped
        placeholder.
        """
        nargs, nexpected = len(args), len(self.positional)
        if nargs < nexpected:
            raise TypeError('not enough arguments for format string')
        if nargs > nexpected:
            raise TypeError(
                'not all arguments converted during string formatting')

        # Merge the positional arguments in the keywords dict.
        for name, value in izip(self.positional, args):
            assert name not in kwds
            kwds[name] = value

        # Patch up the components into a string.
        listexpans = {} # cached list expansions.
        apply_kwds, delay_kwds = {}, self.style_argstype()

        no = count(1)
        style_fmt = self.style_fmt
        # Template for one '<key> = <placeholder>' item of a dict argument.
        dict_fmt = '%%(key)s = %s' % style_fmt
        output = []
        for comp in self.components:
            if isinstance(comp, (str, unicode)):
                output.append(comp)
                continue

            keyname, escaped, sep, fmt = comp

            # Split keyword lists.
            # Expand into lists of words.
            value = kwds[keyname]
            if isinstance(value, (tuple, list, set)):
                try:
                    words = listexpans[keyname] # Try cache.
                except KeyError:
                    # Compute list expansion.
                    words = ['%s_l%d__' % (keyname, i)
                             for i in xrange(len(value))]
                    listexpans[keyname] = words

                if escaped:
                    outfmt = [style_fmt % {'name': w, 'no': no.next()}
                              for w in words]
                else:
                    outfmt = ['%%(%s)%s' % (w, fmt) for w in words]

            elif isinstance(value, dict):
                # If a dict is passed in, the format specified *must* be for
                # escape; we detect this and raise an appropriate error.
                if not escaped:
                    raise ValueError("Attempting to format a dict in "
                                     "an SQL statement without escaping.")

                # Convert dict in a list of comma-separated 'name=value' pairs.
                items = value.items()
                words = ['%s_key_%s__' % (keyname, k) for k, v in items]
                value = [v for k, v in items]
                # Each item consumes a placeholder number like any other
                # escaped value, so that the 'numeric' style works as well.
                outfmt = [dict_fmt % {'key': k, 'name': w, 'no': no.next()}
                          for w, (k, v) in izip(words, items)]

            else:
                words, value = (keyname,), (value,)
                if escaped:
                    outfmt = [style_fmt % {'name': keyname, 'no': no.next()}]
                else:
                    outfmt = ['%%(%s)%s' % (keyname, fmt)]

            # Dispatch values on the appropriate output container: escaped
            # values go to the DBAPI, unescaped ones are formatted in below.
            okwds = delay_kwds if escaped else apply_kwds
            assert len(words) == len(value)
            if isinstance(okwds, dict):
                okwds.update(izip(words, value))
            else:
                okwds.extend(value)

            # Create formatting string.
            output.append(sep.join(outfmt))

        # Apply the unescaped arguments, here, now.
        newquery = ''.join(output)

        # Return the string with the delayed arguments as formatting specifiers,
        # to be formatted by DBAPI, and the delayed arguments.
        return newquery % apply_kwds, delay_kwds

    def execute(self, cursor_, *args, **kwds):
        """
        Execute the analyzed query on the given cursor, with the given arguments
        and keywords.  Returns whatever cursor_.execute() returns.
        """
        # Translate this call into a compatible call to execute().
        cquery, ckwds = self.apply(*args, **kwds)

        # Execute the transformed query.
        return cursor_.execute(cquery, ckwds)


def gensplit(regexp, s):
    """
    Split s on regexp lazily, yielding alternately the text before each match
    and the match object itself, and finally the trailing text (which may be
    an empty string).
    """
    pos = 0
    for match in regexp.finditer(s):
        start, end = match.span()
        yield s[pos:start]
        yield match
        pos = end
    yield s[pos:]



# Default parameter style used when none is specified explicitly.
_def_paramstyle = 'pyformat'

# Parameter styles understood by QueryAnalyzer.init_style(): the PEP 249
# styles plus the non-standard 'atnamed' style.
_supported_paramstyles = ('qmark', 'numeric', 'named', 'format', 'pyformat',
                          'atnamed')

def set_paramstyle(style_or_dbapi):
    """
    Sets the default paramstyle to be used by the underlying DBAPI.
    You can pass in a DBAPI module object or a string. See PEP249 for details.

    Raises ValueError for an unsupported style, in which case the current
    default is left unchanged.
    """
    global _def_paramstyle
    if isinstance(style_or_dbapi, str):
        style = style_or_dbapi
    else:
        style = style_or_dbapi.paramstyle
    # Validate before assigning so a bad value never becomes the default
    # (an assert would also be stripped under python -O).
    if style not in _supported_paramstyles:
        raise ValueError("Parameter style '%s' is not supported." % style)
    _def_paramstyle = style



# Public alias: qcompile(query, paramstyle=None) builds a QueryAnalyzer, which
# can be kept around and its execute() method reused for the same query.
qcompile = QueryAnalyzer
"""
Compile a query in a compatible query analyzer.
"""



# Query cache used to avoid having to analyze the same queries multiple times.
# Hashed on the query string.
_query_cache = {}

# Note: we use cursor_ and query_ because we often call this function with
# vars() which include those names on the caller side.
def execute_f(cursor_, query_, *args, **kwds):
    """
    Fancy execute method for a cursor.  (Note: this is implemented as a function
    but is really meant to be a method to replace or complement the standard
    method Cursor.execute() from DBAPI-2.0.)

    Convert fancy query arguments into a DBAPI-compatible set of arguments and
    execute.

    This method supports a different syntax than the DBAPI execute() method:

    - By default, %s placeholders are not escaped.

    - Use the %S or %(name)S placeholder to specify escaped strings.

    - You can specify positional arguments without having to place them in an
      extra tuple.

    - Keyword arguments are used as expected to fill in missing values.
      Positional arguments are used to fill non-keyword placeholders.

    - Arguments that are tuples, lists or sets will be automatically joined by colons.
      If the corresponding formatting is %S or %(name)S, the members of the
      sequence will be escaped individually.

    See qcompile() for details.

    Note that this function accepts a '_paramstyle' optional argument, to set
    which parameter style to use.
    """
    debug = debug_convert or kwds.pop('__debug__', None)
    if debug:
        print '\n' + '=' * 80
        print '\noriginal ='
        print query_
        print '\nargs ='
        pprint(args)
        print '\nkwds ='
        pprint(kwds)

    # Get the cached query analyzer or create one.
    try:
        q = _query_cache[query_]
    except KeyError:
        _query_cache[query_] = q = qcompile(
            query_,
            paramstyle=kwds.pop('paramstyle', None))

    if debug:
        print '\nquery analyzer =', str(q)

    # Translate this call into a compatible call to execute().
    cquery, ckwds = q.apply(*args, **kwds)

    if debug:
        print '\ntransformed ='
        print cquery
        print '\nnewkwds ='
        pprint(ckwds)

    # Execute the transformed query.
    return cursor_.execute(cquery, ckwds)


# Add support for ntuple wrapping (std in 2.6).
try:
    from collections import namedtuple

    # Patch from Catherine Devlin <catherine dot devlin at gmail dot com>:
    #
    #   "Column names with ``$`` and ``#`` are legal in SQL, but not in
    #   namedtuple field names. This throws exceptions when you try to
    #   execute_obj on queries with such column names. For the apps I write
    #   (rooting around in Oracle data dictionary views), there's no avoiding
    #   the ``$`` and ``#`` characters. Therefore, I added code to munge column
    #   names until they are namedtuple-legal. Another alternative would be to
    #   simply change the error message raised into something that would suggest
    #   that the user use column aliases in the SQL statement to change column
    #   names into something namedtuple-legal."  (2010-05-25)
    #
    # Use the public keyword module rather than collections' private
    # _iskeyword alias, whose absence would otherwise disable execute_obj.
    from keyword import iskeyword as _iskeyword
    not_alphanumeric = re.compile('[^a-zA-Z0-9]')

    def rename_duplicates(lst, append_char = '_'):
        """
        Return a copy of lst in which each repeated item gets append_char
        appended until it differs from every item kept before it.
        """
        newlist = []
        for itm in lst:
            while itm in newlist:
                itm += append_char
            newlist.append(itm)
        return newlist

    def _fix_fieldname(fieldname):
        """
        Ensure that a field name will pass collection.namedtuple's criteria:
        non-alphanumeric characters become underscores, a name that is empty
        or does not start with a letter gets an 'f' prefix (namedtuple rejects
        leading digits and underscores), and keywords get '_' appended.
        """
        fieldname = not_alphanumeric.sub('_', fieldname)
        if not fieldname[:1].isalpha():
            fieldname = 'f' + fieldname
        while _iskeyword(fieldname):
            fieldname = fieldname + '_'
        return fieldname

    def ntuple(typename, field_names, verbose=False):
        """
        Create a namedtuple class, munging the field names until they are
        legal and unique.  field_names may be a whitespace-separated string or
        a sequence of names; with a sequence, a name containing whitespace
        stays a single field.
        """
        if isinstance(field_names, basestring):
            field_names = field_names.split()
        field_names = [_fix_fieldname(fn) for fn in field_names]
        field_names = rename_duplicates(field_names)
        return namedtuple(typename, ' '.join(field_names), verbose)

except ImportError:
    ntuple = None

if ntuple:
    from operator import itemgetter

    def execute_obj(conn, *args, **kwds):
        """
        Run a query on the given connection or cursor and return an iterator
        of ntuples ('Row' instances) over the results.  'conn' can be either a
        Connection or a Cursor object: an object whose class name contains
        'cursor' (case-insensitively) is used directly, anything else has its
        cursor() method called.  The remaining arguments go to execute_f().
        """
        # Convert to a cursor if necessary.
        if re.search('Cursor', conn.__class__.__name__, re.I):
            curs = conn
        else:
            curs = conn.cursor()

        # Execute the query.
        execute_f(curs, *args, **kwds)

        # Pass the column names as a list (not a joined string) so that a name
        # containing whitespace, e.g. a quoted alias, maps to exactly one field.
        names = map(itemgetter(0), curs.description)
        TupleCls = ntuple('Row', names)
        return starmap(TupleCls, imap(tuple, curs))
else:
    execute_obj = None



#-------------------------------------------------------------------------------

class _TestCursor(object):
    """
    Fake cursor that fakes the escaped replacments like a real DBAPI cursor, but
    simply returns the final string.
    """
    execute_f = execute_f

    def execute(self, query, args):
        return self.render_fake(query, args).strip()

    @staticmethod
    def render_fake(query, kwds):
        """
        Take arguments as the DBAPI of execute() accepts and fake escaping the
        arguments as the DBAPI implementation would and return the resulting
        string.  This is used only for testing, to make testing easier and more
        intuitive, to view the completed queries without the replacement
        variables.
        """
        for key, value in kwds.items():
            if isinstance(value, type(None)):
                kwds[key] = 'NULL'
            elif isinstance(value, str):
                kwds[key] = repr(value)
            elif isinstance(value, unicode):
                kwds[key] = repr(value.encode('utf-8'))
            elif isinstance(value, (date, datetime)):
                kwds[key] = repr(value.isoformat())

        result = query % kwds

        if debug_convert:
            print '\n--- 5. after full replacement (fake dbapi application)'
            print result

        return result


def _multi2one(s):
    """
    Collapse a multi-line string onto one line: runs of spaces and newlines
    become a single space, the ends are stripped, and ', ' becomes ','.
    """
    single_line = re.sub('[ \n]+', ' ', s).strip()
    return single_line.replace(', ', ',')


import unittest
class TestExtension(unittest.TestCase):
    """
    Tests for the extention functions.
    """
    def compare_nows(self, s1, s2):
        """
        Compare two strings without considering the whitespace.
        """
        s1 = _multi2one(s1)
        s2 = _multi2one(s2)
        self.assertEquals(s1, s2)

    def test_basic(self):
        "Basic replacement tests."

        cursor = _TestCursor()

        simple, isimple, seq = 'SIMPLE', 42, ('L1', 'L2', 'L3')
        for query, args, kwds, expect in (

            # With simple arguments.
            (' %s ', (simple,), dict(), " SIMPLE "),
            (' %S ', (simple,), dict(), " 'SIMPLE' "),
            (' %X ', (simple,), dict(), " 'SIMPLE' "),
            (' %d ', (isimple,), dict(), " 42 "),
            (' %(k)s ', (), dict(k=simple), " SIMPLE "),
            (' %(k)d ', (), dict(k=isimple), " 42 "),
            (' %(k)S ', (), dict(k=simple), " 'SIMPLE' "),
            (' %(k)X ', (), dict(k=simple), " 'SIMPLE' "),

            # Same but with lists.
            (' %s ', (seq,), dict(), " L1,L2,L3 "),
            (' %S ', (seq,), dict(), " 'L1','L2','L3' "),
            (' %X ', (seq,), dict(), " 'L1','L2','L3' "),
            (' %(k)s ', (), dict(k=seq), " L1,L2,L3 "),
            (' %(k)S ', (), dict(k=seq), " 'L1','L2','L3' "),
            (' %(k)X ', (), dict(k=seq), " 'L1','L2','L3' "),

            ):

            # Normal invocation.
            self.compare_nows(
                cursor.execute_f(query, *args, **kwds),
                expect)

            # Repeated destination formatting string.
            self.compare_nows(
                cursor.execute_f(query + query, *(args + args) , **kwds),
                expect + expect)


    def test_misc(self):

        d = date(2006, 07, 28)

        cursor = _TestCursor()

        self.compare_nows(
            cursor.execute_f('''
              INSERT INTO %(table)s (%s)
                SET VALUES (%S)
                WHERE id = %(id)S
                  AND name IN (%(name)S)
                  AND name NOT IN (%(name)S)
            ''',
                         ('col1', 'col2'),
                         (42, "bli"),
                         id="02351440-7b7e-4260",
                         name=[45, 56, 67, 78],
                         table='table'),
              """
              INSERT INTO table (col1, col2)
                SET VALUES (42, 'bli')
                WHERE id = '02351440-7b7e-4260'
                  AND name IN (45, 56, 67, 78)
                  AND name NOT IN (45, 56, 67, 78)
              """)


        # Note: this should fail in the old text.
        self.compare_nows(
            cursor.execute_f(''' %(id)s AND %(id)S ''',
                         id=['fulano', 'mengano']),
              """ fulano,mengano AND 'fulano','mengano' """)


        self.compare_nows(
            cursor.execute_f('''
              SELECT %s FROM %s WHERE id = %S
            ''',
                         ('id', 'name', 'title'), 'books',
                         '02351440-7b7e-4260'),
            """SELECT id,name,title FROM books
               WHERE id = '02351440-7b7e-4260'""")

        self.compare_nows(
            cursor.execute_f('''
           SELECT %s FROM %s WHERE id = %(id)S %(id)S
        ''', ('id', 'name', 'title'), 'books', id=d),
            """SELECT id,name,title FROM books
               WHERE id = '2006-07-28' '2006-07-28'""")

        self.compare_nows(
            cursor.execute_f(''' %(id)S %(id)S ''', id='02351440-7b7e-4260'),
            " '02351440-7b7e-4260' '02351440-7b7e-4260' ")

        self.compare_nows(
            cursor.execute_f(''' %s %(id)S %(id)s ''',
                         'books',
                         id='02351440-7b7e-4260'),
            "  books '02351440-7b7e-4260' 02351440-7b7e-4260  ")

        self.compare_nows(
            cursor.execute_f('''
              SELECT %s FROM %(table)s WHERE col1 = %S AND col2 < %(val)S
            ''', ('col1', 'col2', 'col3'), 'value1', table='my-table', val=42),
            """ SELECT col1,col2,col3 FROM my-table
                WHERE col1 = 'value1' AND col2 < 42 """)

        self.compare_nows(
            cursor.execute_f("""
              INSERT INTO thumbnails
                (basename, photo1, photo2, photo3)
                VALUES (%S, %S)
                """, 'PHOTONAME', ('BIN1', 'BIN2', 'BIN3')),
            """
              INSERT INTO thumbnails
                (basename, photo1, photo2, photo3)
                VALUES ('PHOTONAME', 'BIN1', 'BIN2', 'BIN3')
                """)


    def test_null(self):
        cursor = _TestCursor()
        self.compare_nows(
            cursor.execute_f('''
              INSERT INTO poodle (hair)
                SET VALUES (%S)
            ''', None),
              """
              INSERT INTO poodle (hair)
                SET VALUES (NULL)
              """)


    def test_paramstyles(self):

        d = date(2006, 07, 28)

        cursor = _TestCursor()

        query = '''
              Simple: %s  Escaped: %S
              Kwd: %(bli)s KwdEscaped: %(bli)S
            '''
        args = ('hansel', 'gretel')
        kwds = dict(bli='bethel')

        test_data = {
            'pyformat': ("""
              Simple: hansel  Escaped: %(__p2)s
              Kwd: bethel KwdEscaped: %(bli)s
            """, {'__p2': 'gretel', 'bli': 'bethel'}),

            'named': ("""
              Simple: hansel  Escaped: :__p2
              Kwd: bethel KwdEscaped: :bli
            """, {'__p2': 'gretel', 'bli': 'bethel'}),

            'qmark': ("""
              Simple: hansel  Escaped: ?
              Kwd: bethel KwdEscaped: ?
            """, ['gretel', 'bethel']),

            'format': ("""
              Simple: hansel  Escaped: %s
              Kwd: bethel KwdEscaped: %s
            """, ['gretel', 'bethel']),

            'numeric': ("""
              Simple: hansel  Escaped: :1
              Kwd: bethel KwdEscaped: :2
            """, ['gretel', 'bethel']),
            }

        for style, (estr, eargs) in test_data.iteritems():
            qstr, qargs = qcompile(query, paramstyle=style).apply(
                *args, **kwds)

            self.compare_nows(qstr, estr)
            self.assertEquals(qargs, eargs)

        # Visual debugging.
        print_it = 0
        for style in test_data.iterkeys():
            qanal = qcompile("""
              %S %(c1)S %S %S %(c2)S
            """, paramstyle=style)

            qstr, qargs = qanal.apply(1, 2, 3, c1='CC1', c2='CC2')
            if print_it:
                print qstr
                print qargs

    def test_dict(self):
        "Tests for passing in a dictionary argument."

        cursor = _TestCursor()
        data = {'brazil': 'portuguese',
                'peru': 'spanish',
                'japan': 'japanese',
                'philipines': 'tagalog'}

        self.assertRaises(ValueError, execute_f,
                          cursor, ' unescaped: %s ', data)

        res = execute_f(cursor, ' UPDATE %s SET %S; ', 'mytable', data)
        self.compare_nows(res, """
           UPDATE mytable
             SET brazil = 'portuguese',
                 japan = 'japanese',
                 philipines = 'tagalog',
                 peru = 'spanish';       """)

    def test_and(self):
        "Tests for passing in a dictionary argument."

        cursor = _TestCursor()
        keydata = {'udid': '11111111111111111111',
                   'imgid': 17}
        valuedata = {'rating': 9}

        self.assertRaises(ValueError, execute_f,
                          cursor, ' unescaped: %s ', keydata)

        res = execute_f(cursor, ' UPDATE %s SET %S WHERE %A; ', 'mytable',
                        valuedata, keydata)
        self.compare_nows(res, """
           UPDATE mytable
             SET rating = 9
             WHERE udid = '11111111111111111111' AND imgid = 17;
        """)

        res = execute_f(cursor, ' UPDATE %s SET %S WHERE %O; ', 'mytable',
                        valuedata, keydata)
        self.compare_nows(res, """
           UPDATE mytable
             SET rating = 9
             WHERE udid = '11111111111111111111' OR imgid = 17;
        """)


    def test_sqlite3(self):
        import sqlite3 as dbapi
        set_paramstyle(dbapi)
        conn = dbapi.connect(':memory:')
        curs = conn.cursor()
        execute_f(curs, """
           CREATE TABLE books (
              author TEXT,
              title TEXT,
              PRIMARY KEY (title)
           );
        """)
        execute_f(curs, """
           INSERT INTO books VALUES (%S);
        """, ("Tolstoy", "War and Peace"))

        execute_f(curs, """
           INSERT INTO books (author) VALUES (%S);
        """, "Dostoyesvki")


# Set to a true value to make execute_f() and _TestCursor.render_fake() print
# the intermediate steps of the query transformation (for debugging).
debug_convert = 0
if __name__ == '__main__':
    unittest.main() # or use nosetests