summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2013-10-20 16:59:56 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2013-10-20 16:59:56 -0400
commit4663ec98b226a7d495846f0d89c646110705bb30 (patch)
treef9eaab0e77a2aced7fce73014661cb8f757060e7 /lib/sqlalchemy/sql
parent0b0764b62ba87bdec41d0fc86618f3779cb4e3f0 (diff)
downloadsqlalchemy-4663ec98b226a7d495846f0d89c646110705bb30.tar.gz
- The typing system now handles the task of rendering "literal bind" values,
e.g. values that are normally bound parameters but due to context must be rendered as strings, typically within DDL constructs such as CHECK constraints and indexes (note that "literal bind" values become used by DDL as of :ticket:`2742`). A new method :meth:`.TypeEngine.literal_processor` serves as the base, and :meth:`.TypeDecorator.process_literal_param` is added to allow wrapping of a native literal rendering method. [ticket:2838] - enhance _get_colparams so that we can send flags like literal_binds into INSERT statements - add support in PG for inspecting standard_conforming_strings - add a new series of roundtrip tests based on INSERT of literal plus SELECT for basic literal rendering in dialect suite
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py87
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py27
-rw-r--r--lib/sqlalchemy/sql/type_api.py73
3 files changed, 141 insertions, 46 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 22906af54..5c7a29f99 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -827,7 +827,7 @@ class SQLCompiler(Compiled):
@util.memoized_property
def _like_percent_literal(self):
- return elements.literal_column("'%'", type_=sqltypes.String())
+ return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE)
def visit_contains_op_binary(self, binary, operator, **kw):
binary = binary._clone()
@@ -871,39 +871,49 @@ class SQLCompiler(Compiled):
def visit_like_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
+
+ # TODO: use ternary here, not "and"/ "or"
return '%s LIKE %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
- (' ESCAPE ' + self.render_literal_value(escape, None))
- or '')
+ + (
+ ' ESCAPE ' +
+ self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape else ''
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
return '%s NOT LIKE %s' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
- (' ESCAPE ' + self.render_literal_value(escape, None))
- or '')
+ + (
+ ' ESCAPE ' +
+ self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape else ''
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) LIKE lower(%s)' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
- (' ESCAPE ' + self.render_literal_value(escape, None))
- or '')
+ + (
+ ' ESCAPE ' +
+ self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape else ''
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
return 'lower(%s) NOT LIKE lower(%s)' % (
binary.left._compiler_dispatch(self, **kw),
binary.right._compiler_dispatch(self, **kw)) \
- + (escape and
- (' ESCAPE ' + self.render_literal_value(escape, None))
- or '')
+ + (
+ ' ESCAPE ' +
+ self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape else ''
+ )
def visit_bindparam(self, bindparam, within_columns_clause=False,
literal_binds=False,
@@ -954,9 +964,6 @@ class SQLCompiler(Compiled):
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.value
- processor = bindparam.type._cached_bind_processor(self.dialect)
- if processor:
- value = processor(value)
return self.render_literal_value(value, bindparam.type)
def render_literal_value(self, value, type_):
@@ -969,22 +976,10 @@ class SQLCompiler(Compiled):
of the DBAPI.
"""
- if isinstance(value, util.string_types):
- value = value.replace("'", "''")
- return "'%s'" % value
- elif value is None:
- return "NULL"
- elif isinstance(value, (float, ) + util.int_types):
- return repr(value)
- elif isinstance(value, decimal.Decimal):
- return str(value)
- elif isinstance(value, util.binary_type):
- # only would occur on py3k b.c. on 2k the string_types
- # directive above catches this.
- # see #2838
- value = value.decode(self.dialect.encoding).replace("'", "''")
- return "'%s'" % value
+ processor = type_._cached_literal_processor(self.dialect)
+ if processor:
+ return processor(value)
else:
raise NotImplementedError(
"Don't know how to literal-quote value %r" % value)
@@ -1599,7 +1594,7 @@ class SQLCompiler(Compiled):
def visit_insert(self, insert_stmt, **kw):
self.isinsert = True
- colparams = self._get_colparams(insert_stmt)
+ colparams = self._get_colparams(insert_stmt, **kw)
if not colparams and \
not self.dialect.supports_default_values and \
@@ -1732,7 +1727,7 @@ class SQLCompiler(Compiled):
table_text = self.update_tables_clause(update_stmt, update_stmt.table,
extra_froms, **kw)
- colparams = self._get_colparams(update_stmt, extra_froms)
+ colparams = self._get_colparams(update_stmt, extra_froms, **kw)
if update_stmt._hints:
dialect_hints = dict([
@@ -1801,7 +1796,7 @@ class SQLCompiler(Compiled):
bindparam._is_crud = True
return bindparam._compiler_dispatch(self)
- def _get_colparams(self, stmt, extra_tables=None):
+ def _get_colparams(self, stmt, extra_tables=None, **kw):
"""create a set of tuples representing column/string pairs for use
in an INSERT or UPDATE statement.
@@ -1853,9 +1848,9 @@ class SQLCompiler(Compiled):
# add it to values() in an "as-is" state,
# coercing right side to bound param
if elements._is_literal(v):
- v = self.process(elements.BindParameter(None, v, type_=k.type))
+ v = self.process(elements.BindParameter(None, v, type_=k.type), **kw)
else:
- v = self.process(v.self_group())
+ v = self.process(v.self_group(), **kw)
values.append((k, v))
@@ -1903,7 +1898,7 @@ class SQLCompiler(Compiled):
c, value, required=value is REQUIRED)
else:
self.postfetch.append(c)
- value = self.process(value.self_group())
+ value = self.process(value.self_group(), **kw)
values.append((c, value))
# determine tables which are actually
# to be updated - process onupdate and
@@ -1915,7 +1910,7 @@ class SQLCompiler(Compiled):
elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, self.process(c.onupdate.arg.self_group()))
+ (c, self.process(c.onupdate.arg.self_group(), **kw))
)
self.postfetch.append(c)
else:
@@ -1941,14 +1936,14 @@ class SQLCompiler(Compiled):
)
elif c.primary_key and implicit_returning:
self.returning.append(c)
- value = self.process(value.self_group())
+ value = self.process(value.self_group(), **kw)
elif implicit_return_defaults and \
c in implicit_return_defaults:
self.returning.append(c)
- value = self.process(value.self_group())
+ value = self.process(value.self_group(), **kw)
else:
self.postfetch.append(c)
- value = self.process(value.self_group())
+ value = self.process(value.self_group(), **kw)
values.append((c, value))
elif self.isinsert:
@@ -1966,13 +1961,13 @@ class SQLCompiler(Compiled):
if self.dialect.supports_sequences and \
(not c.default.optional or \
not self.dialect.sequences_optional):
- proc = self.process(c.default)
+ proc = self.process(c.default, **kw)
values.append((c, proc))
self.returning.append(c)
elif c.default.is_clause_element:
values.append(
(c,
- self.process(c.default.arg.self_group()))
+ self.process(c.default.arg.self_group(), **kw))
)
self.returning.append(c)
else:
@@ -2000,7 +1995,7 @@ class SQLCompiler(Compiled):
if self.dialect.supports_sequences and \
(not c.default.optional or \
not self.dialect.sequences_optional):
- proc = self.process(c.default)
+ proc = self.process(c.default, **kw)
values.append((c, proc))
if implicit_return_defaults and \
c in implicit_return_defaults:
@@ -2009,7 +2004,7 @@ class SQLCompiler(Compiled):
self.postfetch.append(c)
elif c.default.is_clause_element:
values.append(
- (c, self.process(c.default.arg.self_group()))
+ (c, self.process(c.default.arg.self_group(), **kw))
)
if implicit_return_defaults and \
@@ -2037,7 +2032,7 @@ class SQLCompiler(Compiled):
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, self.process(c.onupdate.arg.self_group()))
+ (c, self.process(c.onupdate.arg.self_group(), **kw))
)
if implicit_return_defaults and \
c in implicit_return_defaults:
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 1d7dacb91..01d918120 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -154,6 +154,12 @@ class String(Concatenable, TypeEngine):
self.unicode_error = unicode_error
self._warn_on_bytestring = _warn_on_bytestring
+ def literal_processor(self, dialect):
+ def process(value):
+ value = value.replace("'", "''")
+ return "'%s'" % value
+ return process
+
def bind_processor(self, dialect):
if self.convert_unicode or dialect.convert_unicode:
if dialect.supports_unicode_binds and \
@@ -345,6 +351,11 @@ class Integer(_DateAffinity, TypeEngine):
def python_type(self):
return int
+ def literal_processor(self, dialect):
+ def process(value):
+ return str(value)
+ return process
+
@util.memoized_property
def _expression_adaptations(self):
# TODO: need a dictionary object that will
@@ -481,6 +492,11 @@ class Numeric(_DateAffinity, TypeEngine):
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
+ def literal_processor(self, dialect):
+ def process(value):
+ return str(value)
+ return process
+
@property
def python_type(self):
if self.asdecimal:
@@ -728,6 +744,12 @@ class _Binary(TypeEngine):
def __init__(self, length=None):
self.length = length
+ def literal_processor(self, dialect):
+ def process(value):
+ value = value.decode(self.dialect.encoding).replace("'", "''")
+ return "'%s'" % value
+ return process
+
@property
def python_type(self):
return util.binary_type
@@ -1500,6 +1522,11 @@ class NullType(TypeEngine):
_isnull = True
+ def literal_processor(self, dialect):
+ def process(value):
+ return "NULL"
+ return process
+
class Comparator(TypeEngine.Comparator):
def _adapt_expression(self, op, other_comparator):
if isinstance(other_comparator, NullType.Comparator) or \
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 83b8ec570..698e17472 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -75,6 +75,19 @@ class TypeEngine(Visitable):
def copy_value(self, value):
return value
+ def literal_processor(self, dialect):
+ """Return a conversion function for processing literal values that are
+ to be rendered directly without using binds.
+
+ This function is used when the compiler makes use of the
+ "literal_binds" flag, typically used in DDL generation as well
+ as in certain scenarios where backends don't accept bound parameters.
+
+ .. versionadded:: 0.9.0
+
+ """
+ return None
+
def bind_processor(self, dialect):
"""Return a conversion function for processing bind values.
@@ -265,6 +278,16 @@ class TypeEngine(Visitable):
except KeyError:
return self._dialect_info(dialect)['impl']
+
+ def _cached_literal_processor(self, dialect):
+ """Return a dialect-specific literal processor for this type."""
+ try:
+ return dialect._type_memos[self]['literal']
+ except KeyError:
+ d = self._dialect_info(dialect)
+ d['literal'] = lp = d['impl'].literal_processor(dialect)
+ return lp
+
def _cached_bind_processor(self, dialect):
"""Return a dialect-specific bind processor for this type."""
@@ -673,6 +696,22 @@ class TypeDecorator(TypeEngine):
implementation."""
return getattr(self.impl, key)
+ def process_literal_param(self, value, dialect):
+ """Receive a literal parameter value to be rendered inline within
+ a statement.
+
+ This method is used when the compiler renders a
+ literal value without using binds, typically within DDL
+ such as in the "server default" of a column or an expression
+ within a CHECK constraint.
+
+ The returned string will be rendered into the output string.
+
+ .. versionadded:: 0.9.0
+
+ """
+ raise NotImplementedError()
+
def process_bind_param(self, value, dialect):
"""Receive a bound parameter value to be converted.
@@ -737,6 +776,40 @@ class TypeDecorator(TypeEngine):
return self.__class__.process_bind_param.__code__ \
is not TypeDecorator.process_bind_param.__code__
+ @util.memoized_property
+ def _has_literal_processor(self):
+ """memoized boolean, check if process_literal_param is implemented.
+
+
+ """
+
+ return self.__class__.process_literal_param.__code__ \
+ is not TypeDecorator.process_literal_param.__code__
+
+ def literal_processor(self, dialect):
+ """Provide a literal processing function for the given
+ :class:`.Dialect`.
+
+ Subclasses here will typically override :meth:`.TypeDecorator.process_literal_param`
+ instead of this method directly.
+
+ .. versionadded:: 0.9.0
+
+ """
+ if self._has_literal_processor:
+ process_param = self.process_literal_param
+ impl_processor = self.impl.literal_processor(dialect)
+ if impl_processor:
+ def process(value):
+ return impl_processor(process_param(value, dialect))
+ else:
+ def process(value):
+ return process_param(value, dialect)
+
+ return process
+ else:
+ return self.impl.literal_processor(dialect)
+
def bind_processor(self, dialect):
"""Provide a bound value processing function for the
given :class:`.Dialect`.