summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2015-08-22 12:47:13 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2015-08-22 12:47:13 -0400
commit8712ef2f81498fe59b9636ba150833d779e60781 (patch)
tree8e42639b148196a22d930fdef851b44686886279
parente0a8030048f4ad0690d3084929441bda4c21aba2 (diff)
downloadsqlalchemy-8712ef2f81498fe59b9636ba150833d779e60781.tar.gz
- Added new checks for the common error case of passing mapped classes
or mapped instances into contexts where they are interpreted as SQL bound parameters; a new exception is raised for this. fixes #3321
-rw-r--r--doc/build/changelog/changelog_11.rst12
-rw-r--r--doc/build/changelog/migration_11.rst41
-rw-r--r--lib/sqlalchemy/sql/compiler.py2
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py4
-rw-r--r--lib/sqlalchemy/sql/elements.py6
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py20
-rw-r--r--lib/sqlalchemy/sql/type_api.py3
-rw-r--r--lib/sqlalchemy/types.py1
-rw-r--r--test/aaa_profiling/test_compiler.py4
-rw-r--r--test/orm/test_query.py36
-rw-r--r--test/sql/test_types.py25
11 files changed, 139 insertions, 15 deletions
diff --git a/doc/build/changelog/changelog_11.rst b/doc/build/changelog/changelog_11.rst
index 0f974dc8c..695aa3c5c 100644
--- a/doc/build/changelog/changelog_11.rst
+++ b/doc/build/changelog/changelog_11.rst
@@ -22,6 +22,18 @@
:version: 1.1.0b1
.. change::
+ :tags: feature, orm
+ :tickets: 3321
+
+ Added new checks for the common error case of passing mapped classes
+ or mapped instances into contexts where they are interpreted as
+ SQL bound parameters; a new exception is raised for this.
+
+ .. seealso::
+
+ :ref:`change_3321`
+
+ .. change::
:tags: bug, postgresql
:tickets: 3499
diff --git a/doc/build/changelog/migration_11.rst b/doc/build/changelog/migration_11.rst
index c40d5a9c1..849d4516b 100644
--- a/doc/build/changelog/migration_11.rst
+++ b/doc/build/changelog/migration_11.rst
@@ -104,6 +104,47 @@ approach which applied a counter to the object.
:ticket:`3499`
+.. _change_3321:
+
+Specific checks added for passing mapped classes, instances as SQL literals
+---------------------------------------------------------------------------
+
+The typing system now has specific checks for passing of SQLAlchemy
+"inspectable" objects in contexts where they would otherwise be handled as
+literal values. Any SQLAlchemy built-in object that is legal to pass as a
+SQL value includes a method ``__clause_element__()`` which provides a
+valid SQL expression for that object. For SQLAlchemy objects that
+don't provide this, such as mapped classes, mappers, and mapped
+instances, a more informative error message is emitted rather than
+allowing the DBAPI to receive the object and fail later. An example
+is illustrated below, where a string-based attribute ``User.name`` is
+compared to a full instance of ``User()``, rather than against a
+string value::
+
+ >>> some_user = User()
+ >>> q = s.query(User).filter(User.name == some_user)
+ ...
+ sqlalchemy.exc.ArgumentError: Object <__main__.User object at 0x103167e90> is not legal as a SQL literal value
+
+The exception is now immediate when the comparison is made between
+``User.name == some_user``. Previously, a comparison like the above
+would produce a SQL expression that would only fail once resolved
+into a DBAPI execution call; the mapped ``User`` object would
+ultimately become a bound parameter that would be rejected by the
+DBAPI.
+
+Note that in the above example, the expression fails because
+``User.name`` is a string-based (e.g. column oriented) attribute.
+The change does *not* impact the usual case of comparing a many-to-one
+relationship attribute to an object, which is handled distinctly::
+
+ >>> # Address.user refers to the User mapper, so
+ >>> # this is of course still OK!
+ >>> q = s.query(Address).filter(Address.user == some_user)
+
+
+:ticket:`3321`
+
New Features and Improvements - Core
====================================
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index d3c46e643..4717b777f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -281,6 +281,8 @@ class _CompileLabel(visitors.Visitable):
def type(self):
return self.element.type
+ def self_group(self, **kw):
+ return self
class SQLCompiler(Compiled):
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 09f639163..125fec33f 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -15,7 +15,7 @@ from .elements import BindParameter, True_, False_, BinaryExpression, \
Null, _const_expr, _clause_element_as_expr, \
ClauseList, ColumnElement, TextClause, UnaryExpression, \
collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \
- Slice
+ Slice, Visitable
from .selectable import SelectBase, Alias, Selectable, ScalarSelect
@@ -304,7 +304,7 @@ def _check_literal(expr, operator, other):
if isinstance(other, (SelectBase, Alias)):
return other.as_scalar()
- elif not isinstance(other, (ColumnElement, TextClause)):
+ elif not isinstance(other, Visitable):
return expr._bind_param(operator, other)
else:
return other
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 00c749b40..e2d81afc1 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -1145,8 +1145,7 @@ class BindParameter(ColumnElement):
_compared_to_type.coerce_compared_value(
_compared_to_operator, value)
else:
- self.type = type_api._type_map.get(type(value),
- type_api.NULLTYPE)
+ self.type = type_api._resolve_value_to_type(value)
elif isinstance(type_, type):
self.type = type_()
else:
@@ -1161,8 +1160,7 @@ class BindParameter(ColumnElement):
cloned.callable = None
cloned.required = False
if cloned.type is type_api.NULLTYPE:
- cloned.type = type_api._type_map.get(type(value),
- type_api.NULLTYPE)
+ cloned.type = type_api._resolve_value_to_type(value)
return cloned
@property
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index ec7dea300..b5c575143 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -9,7 +9,6 @@
"""
-import collections
import datetime as dt
import codecs
@@ -18,6 +17,7 @@ from .elements import quoted_name, type_coerce, _defer_name
from .. import exc, util, processors
from .base import _bind_or_error, SchemaEventTarget
from . import operators
+from .. import inspection
from .. import event
from ..util import pickle
import decimal
@@ -1736,6 +1736,21 @@ else:
_type_map[unicode] = Unicode()
_type_map[str] = String()
+_type_map_get = _type_map.get
+
+
+def _resolve_value_to_type(value):
+ _result_type = _type_map_get(type(value), False)
+ if _result_type is False:
+ # use inspect() to detect SQLAlchemy built-in
+ # objects.
+ insp = inspection.inspect(value, False)
+ if insp is not None:
+ raise exc.ArgumentError(
+ "Object %r is not legal as a SQL literal value" % value)
+ return NULLTYPE
+ else:
+ return _result_type
# back-assign to type_api
from . import type_api
@@ -1745,6 +1760,5 @@ type_api.INTEGERTYPE = INTEGERTYPE
type_api.NULLTYPE = NULLTYPE
type_api.MATCHTYPE = MATCHTYPE
type_api.INDEXABLE = Indexable
-type_api._type_map = _type_map
-
+type_api._resolve_value_to_type = _resolve_value_to_type
TypeEngine.Comparator.BOOLEANTYPE = BOOLEANTYPE
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index c4e830b7f..b9826e585 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -21,6 +21,7 @@ NULLTYPE = None
STRINGTYPE = None
MATCHTYPE = None
INDEXABLE = None
+_resolve_value_to_type = None
class TypeEngine(Visitable):
@@ -454,7 +455,7 @@ class TypeEngine(Visitable):
end-user customization of this behavior.
"""
- _coerced_type = _type_map.get(type(value), NULLTYPE)
+ _coerced_type = _resolve_value_to_type(value)
if _coerced_type is NULLTYPE or _coerced_type._type_affinity \
is self._type_affinity:
return self
diff --git a/lib/sqlalchemy/types.py b/lib/sqlalchemy/types.py
index 3a0e2a58f..61b89969f 100644
--- a/lib/sqlalchemy/types.py
+++ b/lib/sqlalchemy/types.py
@@ -76,5 +76,4 @@ from .sql.sqltypes import (
UnicodeText,
VARBINARY,
VARCHAR,
- _type_map
)
diff --git a/test/aaa_profiling/test_compiler.py b/test/aaa_profiling/test_compiler.py
index 5eece4602..5095be103 100644
--- a/test/aaa_profiling/test_compiler.py
+++ b/test/aaa_profiling/test_compiler.py
@@ -32,8 +32,8 @@ class CompileTest(fixtures.TestBase, AssertsExecutionResults):
for t in (t1, t2):
for c in t.c:
c.type._type_affinity
- from sqlalchemy import types
- for t in list(types._type_map.values()):
+ from sqlalchemy.sql import sqltypes
+ for t in list(sqltypes._type_map.values()):
t._type_affinity
cls.dialect = default.DefaultDialect()
diff --git a/test/orm/test_query.py b/test/orm/test_query.py
index 3ed2e7d7a..b0501739f 100644
--- a/test/orm/test_query.py
+++ b/test/orm/test_query.py
@@ -776,6 +776,42 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL):
meth, q, *arg, **kw
)
+ def test_illegal_coercions(self):
+ User = self.classes.User
+
+ assert_raises_message(
+ sa_exc.ArgumentError,
+ "Object .*User.* is not legal as a SQL literal value",
+ distinct, User
+ )
+
+ ua = aliased(User)
+ assert_raises_message(
+ sa_exc.ArgumentError,
+ "Object .*User.* is not legal as a SQL literal value",
+ distinct, ua
+ )
+
+ s = Session()
+ assert_raises_message(
+ sa_exc.ArgumentError,
+ "Object .*User.* is not legal as a SQL literal value",
+ lambda: s.query(User).filter(User.name == User)
+ )
+
+ u1 = User()
+ assert_raises_message(
+ sa_exc.ArgumentError,
+ "Object .*User.* is not legal as a SQL literal value",
+ distinct, u1
+ )
+
+ assert_raises_message(
+ sa_exc.ArgumentError,
+ "Object .*User.* is not legal as a SQL literal value",
+ lambda: s.query(User).filter(User.name == u1)
+ )
+
class OperatorTest(QueryTest, AssertsCompiledSQL):
"""test sql.Comparator implementation for MapperProperties"""
diff --git a/test/sql/test_types.py b/test/sql/test_types.py
index 0ab8ef451..90fac97c2 100644
--- a/test/sql/test_types.py
+++ b/test/sql/test_types.py
@@ -1,5 +1,6 @@
# coding: utf-8
-from sqlalchemy.testing import eq_, assert_raises, assert_raises_message, expect_warnings
+from sqlalchemy.testing import eq_, is_, assert_raises, \
+ assert_raises_message, expect_warnings
import decimal
import datetime
import os
@@ -11,7 +12,7 @@ from sqlalchemy import (
BLOB, NCHAR, NVARCHAR, CLOB, TIME, DATE, DATETIME, TIMESTAMP, SMALLINT,
INTEGER, DECIMAL, NUMERIC, FLOAT, REAL)
from sqlalchemy.sql import ddl
-
+from sqlalchemy import inspection
from sqlalchemy import exc, types, util, dialects
for name in dialects.__all__:
__import__("sqlalchemy.dialects.%s" % name)
@@ -1647,6 +1648,26 @@ class ExpressionTest(
assert distinct(test_table.c.data).type == test_table.c.data.type
assert test_table.c.data.distinct().type == test_table.c.data.type
+ def test_detect_coercion_of_builtins(self):
+ @inspection._self_inspects
+ class SomeSQLAThing(object):
+ def __repr__(self):
+ return "some_sqla_thing()"
+
+ class SomeOtherThing(object):
+ pass
+
+ assert_raises_message(
+ exc.ArgumentError,
+ r"Object some_sqla_thing\(\) is not legal as a SQL literal value",
+ lambda: column('a', String) == SomeSQLAThing()
+ )
+
+ is_(
+ bindparam('x', SomeOtherThing()).type,
+ types.NULLTYPE
+ )
+
class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
__dialect__ = 'default'