diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-10-20 16:59:56 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2013-10-20 16:59:56 -0400 |
| commit | 4663ec98b226a7d495846f0d89c646110705bb30 (patch) | |
| tree | f9eaab0e77a2aced7fce73014661cb8f757060e7 /lib/sqlalchemy/sql | |
| parent | 0b0764b62ba87bdec41d0fc86618f3779cb4e3f0 (diff) | |
| download | sqlalchemy-4663ec98b226a7d495846f0d89c646110705bb30.tar.gz | |
- The typing system now handles the task of rendering "literal bind" values,
e.g. values that are normally bound parameters but due to context must
be rendered as strings, typically within DDL constructs such as
CHECK constraints and indexes (note that "literal bind" values
become used by DDL as of :ticket:`2742`). A new method
:meth:`.TypeEngine.literal_processor` serves as the base, and
:meth:`.TypeDecorator.process_literal_param` is added to allow wrapping
of a native literal rendering method. [ticket:2838]
- enhance _get_colparams so that we can send flags like literal_binds into
INSERT statements
- add support in PG for inspecting standard_conforming_strings
- add a new series of roundtrip tests based on INSERT of literal plus SELECT
for basic literal rendering in dialect suite
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 87 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 27 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 73 |
3 files changed, 141 insertions, 46 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 22906af54..5c7a29f99 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -827,7 +827,7 @@ class SQLCompiler(Compiled): @util.memoized_property def _like_percent_literal(self): - return elements.literal_column("'%'", type_=sqltypes.String()) + return elements.literal_column("'%'", type_=sqltypes.STRINGTYPE) def visit_contains_op_binary(self, binary, operator, **kw): binary = binary._clone() @@ -871,39 +871,49 @@ class SQLCompiler(Compiled): def visit_like_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) + + # TODO: use ternary here, not "and"/ "or" return '%s LIKE %s' % ( binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def visit_notlike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return '%s NOT LIKE %s' % ( binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) LIKE lower(%s)' % ( binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) return 'lower(%s) NOT LIKE lower(%s)' % ( binary.left._compiler_dispatch(self, **kw), binary.right._compiler_dispatch(self, **kw)) \ - + (escape and - (' ESCAPE ' + self.render_literal_value(escape, None)) - or '') + + ( + ' ESCAPE ' + + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape else '' + ) def visit_bindparam(self, bindparam, within_columns_clause=False, literal_binds=False, @@ -954,9 +964,6 @@ class SQLCompiler(Compiled): def render_literal_bindparam(self, bindparam, **kw): value = bindparam.value - processor = bindparam.type._cached_bind_processor(self.dialect) - if processor: - value = processor(value) return self.render_literal_value(value, bindparam.type) def render_literal_value(self, value, type_): @@ -969,22 +976,10 @@ class SQLCompiler(Compiled): of the DBAPI. """ - if isinstance(value, util.string_types): - value = value.replace("'", "''") - return "'%s'" % value - elif value is None: - return "NULL" - elif isinstance(value, (float, ) + util.int_types): - return repr(value) - elif isinstance(value, decimal.Decimal): - return str(value) - elif isinstance(value, util.binary_type): - # only would occur on py3k b.c. on 2k the string_types - # directive above catches this. - # see #2838 - value = value.decode(self.dialect.encoding).replace("'", "''") - return "'%s'" % value + processor = type_._cached_literal_processor(self.dialect) + if processor: + return processor(value) else: raise NotImplementedError( "Don't know how to literal-quote value %r" % value) @@ -1599,7 +1594,7 @@ class SQLCompiler(Compiled): def visit_insert(self, insert_stmt, **kw): self.isinsert = True - colparams = self._get_colparams(insert_stmt) + colparams = self._get_colparams(insert_stmt, **kw) if not colparams and \ not self.dialect.supports_default_values and \ @@ -1732,7 +1727,7 @@ class SQLCompiler(Compiled): table_text = self.update_tables_clause(update_stmt, update_stmt.table, extra_froms, **kw) - colparams = self._get_colparams(update_stmt, extra_froms) + colparams = self._get_colparams(update_stmt, extra_froms, **kw) if update_stmt._hints: dialect_hints = dict([ @@ -1801,7 +1796,7 @@ class SQLCompiler(Compiled): bindparam._is_crud = True return bindparam._compiler_dispatch(self) - def _get_colparams(self, stmt, extra_tables=None): + def _get_colparams(self, stmt, extra_tables=None, **kw): """create a set of tuples representing column/string pairs for use in an INSERT or UPDATE statement. @@ -1853,9 +1848,9 @@ class SQLCompiler(Compiled): # add it to values() in an "as-is" state, # coercing right side to bound param if elements._is_literal(v): - v = self.process(elements.BindParameter(None, v, type_=k.type)) + v = self.process(elements.BindParameter(None, v, type_=k.type), **kw) else: - v = self.process(v.self_group()) + v = self.process(v.self_group(), **kw) values.append((k, v)) @@ -1903,7 +1898,7 @@ class SQLCompiler(Compiled): c, value, required=value is REQUIRED) else: self.postfetch.append(c) - value = self.process(value.self_group()) + value = self.process(value.self_group(), **kw) values.append((c, value)) # determine tables which are actually # to be updated - process onupdate and @@ -1915,7 +1910,7 @@ class SQLCompiler(Compiled): elif c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, self.process(c.onupdate.arg.self_group())) + (c, self.process(c.onupdate.arg.self_group(), **kw)) ) self.postfetch.append(c) else: @@ -1941,14 +1936,14 @@ class SQLCompiler(Compiled): ) elif c.primary_key and implicit_returning: self.returning.append(c) - value = self.process(value.self_group()) + value = self.process(value.self_group(), **kw) elif implicit_return_defaults and \ c in implicit_return_defaults: self.returning.append(c) - value = self.process(value.self_group()) + value = self.process(value.self_group(), **kw) else: self.postfetch.append(c) - value = self.process(value.self_group()) + value = self.process(value.self_group(), **kw) values.append((c, value)) elif self.isinsert: @@ -1966,13 +1961,13 @@ class SQLCompiler(Compiled): if self.dialect.supports_sequences and \ (not c.default.optional or \ not self.dialect.sequences_optional): - proc = self.process(c.default) + proc = self.process(c.default, **kw) values.append((c, proc)) self.returning.append(c) elif c.default.is_clause_element: values.append( (c, - self.process(c.default.arg.self_group())) + self.process(c.default.arg.self_group(), **kw)) ) self.returning.append(c) else: @@ -2000,7 +1995,7 @@ class SQLCompiler(Compiled): if self.dialect.supports_sequences and \ (not c.default.optional or \ not self.dialect.sequences_optional): - proc = self.process(c.default) + proc = self.process(c.default, **kw) values.append((c, proc)) if implicit_return_defaults and \ c in implicit_return_defaults: @@ -2009,7 +2004,7 @@ class SQLCompiler(Compiled): self.postfetch.append(c) elif c.default.is_clause_element: values.append( - (c, self.process(c.default.arg.self_group())) + (c, self.process(c.default.arg.self_group(), **kw)) ) if implicit_return_defaults and \ @@ -2037,7 +2032,7 @@ class SQLCompiler(Compiled): if c.onupdate is not None and not c.onupdate.is_sequence: if c.onupdate.is_clause_element: values.append( - (c, self.process(c.onupdate.arg.self_group())) + (c, self.process(c.onupdate.arg.self_group(), **kw)) ) if implicit_return_defaults and \ c in implicit_return_defaults: diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 1d7dacb91..01d918120 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -154,6 +154,12 @@ class String(Concatenable, TypeEngine): self.unicode_error = unicode_error self._warn_on_bytestring = _warn_on_bytestring + def literal_processor(self, dialect): + def process(value): + value = value.replace("'", "''") + return "'%s'" % value + return process + def bind_processor(self, dialect): if self.convert_unicode or dialect.convert_unicode: if dialect.supports_unicode_binds and \ @@ -345,6 +351,11 @@ class Integer(_DateAffinity, TypeEngine): def python_type(self): return int + def literal_processor(self, dialect): + def process(value): + return str(value) + return process + @util.memoized_property def _expression_adaptations(self): # TODO: need a dictionary object that will @@ -481,6 +492,11 @@ class Numeric(_DateAffinity, TypeEngine): def get_dbapi_type(self, dbapi): return dbapi.NUMBER + def literal_processor(self, dialect): + def process(value): + return str(value) + return process + @property def python_type(self): if self.asdecimal: @@ -728,6 +744,12 @@ class _Binary(TypeEngine): def __init__(self, length=None): self.length = length + def literal_processor(self, dialect): + def process(value): + value = value.decode(self.dialect.encoding).replace("'", "''") + return "'%s'" % value + return process + @property def python_type(self): return util.binary_type @@ -1500,6 +1522,11 @@ class NullType(TypeEngine): _isnull = True + def literal_processor(self, dialect): + def process(value): + return "NULL" + return process + class Comparator(TypeEngine.Comparator): def _adapt_expression(self, op, other_comparator): if isinstance(other_comparator, NullType.Comparator) or \ diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 83b8ec570..698e17472 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -75,6 +75,19 @@ class TypeEngine(Visitable): def copy_value(self, value): return value + def literal_processor(self, dialect): + """Return a conversion function for processing literal values that are + to be rendered directly without using binds. + + This function is used when the compiler makes use of the + "literal_binds" flag, typically used in DDL generation as well + as in certain scenarios where backends don't accept bound parameters. + + .. versionadded:: 0.9.0 + + """ + return None + def bind_processor(self, dialect): """Return a conversion function for processing bind values. @@ -265,6 +278,16 @@ class TypeEngine(Visitable): except KeyError: return self._dialect_info(dialect)['impl'] + + def _cached_literal_processor(self, dialect): + """Return a dialect-specific literal processor for this type.""" + try: + return dialect._type_memos[self]['literal'] + except KeyError: + d = self._dialect_info(dialect) + d['literal'] = lp = d['impl'].literal_processor(dialect) + return lp + def _cached_bind_processor(self, dialect): """Return a dialect-specific bind processor for this type.""" @@ -673,6 +696,22 @@ class TypeDecorator(TypeEngine): implementation.""" return getattr(self.impl, key) + def process_literal_param(self, value, dialect): + """Receive a literal parameter value to be rendered inline within + a statement. + + This method is used when the compiler renders a + literal value without using binds, typically within DDL + such as in the "server default" of a column or an expression + within a CHECK constraint. + + The returned string will be rendered into the output string. + + .. versionadded:: 0.9.0 + + """ + raise NotImplementedError() + def process_bind_param(self, value, dialect): """Receive a bound parameter value to be converted. @@ -737,6 +776,40 @@ class TypeDecorator(TypeEngine): return self.__class__.process_bind_param.__code__ \ is not TypeDecorator.process_bind_param.__code__ + @util.memoized_property + def _has_literal_processor(self): + """memoized boolean, check if process_literal_param is implemented. + + + """ + + return self.__class__.process_literal_param.__code__ \ + is not TypeDecorator.process_literal_param.__code__ + + def literal_processor(self, dialect): + """Provide a literal processing function for the given + :class:`.Dialect`. + + Subclasses here will typically override :meth:`.TypeDecorator.process_literal_param` + instead of this method directly. + + .. versionadded:: 0.9.0 + + """ + if self._has_literal_processor: + process_param = self.process_literal_param + impl_processor = self.impl.literal_processor(dialect) + if impl_processor: + def process(value): + return impl_processor(process_param(value, dialect)) + else: + def process(value): + return process_param(value, dialect) + + return process + else: + return self.impl.literal_processor(dialect) + def bind_processor(self, dialect): """Provide a bound value processing function for the given :class:`.Dialect`. |
