diff options
25 files changed, 802 insertions, 84 deletions
diff --git a/doc/build/changelog/migration_14.rst b/doc/build/changelog/migration_14.rst index 3a57191e3..63b59841f 100644 --- a/doc/build/changelog/migration_14.rst +++ b/doc/build/changelog/migration_14.rst @@ -756,6 +756,136 @@ the cascade settings for a viewonly relationship. :ticket:`4993` :ticket:`4994` +New Features - Core +==================== + +.. _change_4737: + + +Built-in FROM linting will warn for any potential cartesian products in a SELECT statement +------------------------------------------------------------------------------------------ + +As the Core expression language as well as the ORM are built on an "implicit +FROMs" model where a particular FROM clause is automatically added if any part +of the query refers to it, a common issue is the case where a SELECT statement, +either a top level statement or an embedded subquery, contains FROM elements +that are not joined to the rest of the FROM elements in the query, causing +what's referred to as a "cartesian product" in the result set, i.e. every +possible combination of rows from each FROM element not otherwise joined. In +relational databases, this is nearly always an undesirable outcome as it +produces an enormous result set full of duplicated, uncorrelated data. + +SQLAlchemy, for all of its great features, is particularly prone to this sort +of issue happening as a SELECT statement will have elements added to its FROM +clause automatically from any table seen in the other clauses. A typical +scenario looks like the following, where two tables are JOINed together, +however an additional entry in the WHERE clause that perhaps inadvertently does +not line up with these two tables will create an additional FROM entry:: + + address_alias = aliased(Address) + + q = session.query(User).\ + join(address_alias, User.addresses).\ + filter(Address.email_address == 'foo') + +The above query selects from a JOIN of ``User`` and ``address_alias``, the +latter of which is an alias of the ``Address`` entity. However, the +``Address`` entity is used within the WHERE clause directly, so the above would +result in the SQL:: + + SELECT + users.id AS users_id, users.name AS users_name, + users.fullname AS users_fullname, + users.nickname AS users_nickname + FROM addresses, users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id + WHERE addresses.email_address = :email_address_1 + +In the above SQL, we can see what SQLAlchemy developers term "the dreaded +comma", as we see "FROM addresses, users JOIN addresses" in the FROM clause +which is the classic sign of a cartesian product; where a query is making use +of JOIN in order to join FROM clauses together, however because one of them is +not joined, it uses a comma. The above query will return a full set of +rows that join the "user" and "addresses" table together on the "id / user_id" +column, and will then apply all those rows into a cartesian product against +every row in the "addresses" table directly. That is, if there are ten user +rows and 100 rows in addresses, the above query will return its expected result +rows, likely to be 100 as all address rows would be selected, multiplied by 100 +again, so that the total result size would be 10000 rows. + +The "table1, table2 JOIN table3" pattern is one that also occurs quite +frequently within the SQLAlchemy ORM due to either subtle mis-application of +ORM features particularly those related to joined eager loading or joined table +inheritance, as well as a result of SQLAlchemy ORM bugs within those same +systems. Similar issues apply to SELECT statements that use "implicit joins", +where the JOIN keyword is not used and instead each FROM element is linked with +another one via the WHERE clause. + +For some years there has been a recipe on the Wiki that applies a graph +algorithm to a :func:`.select` construct at query execution time and inspects +the structure of the query for these un-linked FROM clauses, parsing through +the WHERE clause and all JOIN clauses to determine how FROM elements are linked +together and ensuring that all the FROM elements are connected in a single +graph. This recipe has now been adapted to be part of the :class:`.SQLCompiler` +itself where it now optionally emits a warning for a statement if this +condition is detected. The warning is enabled using the +:paramref:`.create_engine.enable_from_linting` flag and is enabled by default. +The computational overhead of the linter is very low, and additionally it only +occurs during statement compilation which means for a cached SQL statement it +only occurs once. + +Using this feature, our ORM query above will emit a warning:: + + >>> q.all() + SAWarning: SELECT statement has a cartesian product between FROM + element(s) "addresses_1", "users" and FROM element "addresses". + Apply join condition(s) between each element to resolve. + +The linter feature accommodates not just for tables linked together through the +JOIN clauses but also through the WHERE clause Above, we can add a WHERE +clause to link the new ``Address`` entity with the previous ``address_alias`` +entity and that will remove the warning:: + + q = session.query(User).\ + join(address_alias, User.addresses).\ + filter(Address.email_address == 'foo').\ + filter(Address.id == address_alias.id) # resolve cartesian products, + # will no longer warn + +The cartesian product warning considers **any** kind of link between two +FROM clauses to be a resolution, even if the end result set is still +wasteful, as the linter is intended only to detect the common case of a +FROM clause that is completely unexpected. If the FROM clause is referred +to explicitly elsewhere and linked to the other FROMs, no warning is emitted:: + + q = session.query(User).\ + join(address_alias, User.addresses).\ + filter(Address.email_address == 'foo').\ + filter(Address.id > address_alias.id) # will generate a lot of rows, + # but no warning + +Full cartesian products are also allowed if they are explicitly stated; if we +wanted for example the cartesian product of ``User`` and ``Address``, we can +JOIN on :func:`.true` so that every row will match with every other; the +following query will return all rows and produce no warnings:: + + from sqlalchemy import true + + # intentional cartesian product + q = session.query(User).join(Address, true()) # intentional cartesian product + +The warning is only generated by default when the statement is compiled by the +:class:`.Connection` for execution; calling the :meth:`.ClauseElement.compile` +method will not emit a warning unless the linting flag is supplied:: + + >>> from sqlalchemy.sql import FROM_LINTING + >>> print(q.statement.compile(linting=FROM_LINTING)) + SAWarning: SELECT statement has a cartesian product between FROM element(s) "addresses" and FROM element "users". Apply join condition(s) between each element to resolve. + SELECT users.id, users.name, users.fullname, users.nickname + FROM addresses, users JOIN addresses AS addresses_1 ON users.id = addresses_1.user_id + WHERE addresses.email_address = :email_address_1 + +:ticket:`4737` + Behavior Changes - Core diff --git a/doc/build/changelog/unreleased_14/4737.rst b/doc/build/changelog/unreleased_14/4737.rst new file mode 100644 index 000000000..072788ee8 --- /dev/null +++ b/doc/build/changelog/unreleased_14/4737.rst @@ -0,0 +1,18 @@ +.. change:: + :tags: feature,sql + :tickets: 4737 + + Added "from linting" as a built-in feature to the SQL compiler. This + allows the compiler to maintain graph of all the FROM clauses in a + particular SELECT statement, linked by criteria in either the WHERE + or in JOIN clauses that link these FROM clauses together. If any two + FROM clauses have no path between them, a warning is emitted that the + query may be producing a cartesian product. As the Core expression + language as well as the ORM are built on an "implicit FROMs" model where + a particular FROM clause is automatically added if any part of the query + refers to it, it is easy for this to happen inadvertently and it is + hoped that the new feature helps with this issue. + + .. seealso:: + + :ref:`change_4737` diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 8241d951b..6e84c9da1 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1434,7 +1434,10 @@ class MySQLCompiler(compiler.SQLCompiler): else: return "" - def visit_join(self, join, asfrom=False, **kwargs): + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.add((join.left, join.right)) + if join.full: join_type = " FULL OUTER JOIN " elif join.isouter: @@ -1444,11 +1447,15 @@ class MySQLCompiler(compiler.SQLCompiler): return "".join( ( - self.process(join.left, asfrom=True, **kwargs), + self.process( + join.left, asfrom=True, from_linter=from_linter, **kwargs + ), join_type, - self.process(join.right, asfrom=True, **kwargs), + self.process( + join.right, asfrom=True, from_linter=from_linter, **kwargs + ), " ON ", - self.process(join.onclause, **kwargs), + self.process(join.onclause, from_linter=from_linter, **kwargs), ) ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 9cb25b934..87e0baa58 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -829,19 +829,24 @@ class OracleCompiler(compiler.SQLCompiler): return " FROM DUAL" - def visit_join(self, join, **kwargs): + def visit_join(self, join, from_linter=None, **kwargs): if self.dialect.use_ansi: - return compiler.SQLCompiler.visit_join(self, join, **kwargs) + return compiler.SQLCompiler.visit_join( + self, join, from_linter=from_linter, **kwargs + ) else: + if from_linter: + from_linter.edges.add((join.left, join.right)) + kwargs["asfrom"] = True if isinstance(join.right, expression.FromGrouping): right = join.right.element else: right = join.right return ( - self.process(join.left, **kwargs) + self.process(join.left, from_linter=from_linter, **kwargs) + ", " - + self.process(right, **kwargs) + + self.process(right, from_linter=from_linter, **kwargs) ) def _get_nonansi_join_whereclause(self, froms): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 88558df5d..462e5f9ec 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -16,6 +16,7 @@ from .. import exc from .. import inspection from .. import log from .. import util +from ..sql import compiler from ..sql import schema from ..sql import util as sql_util @@ -1083,6 +1084,8 @@ class Connection(Connectable): schema_translate_map=self.schema_for_object if not self.schema_for_object.is_default else None, + linting=self.dialect.compiler_linting + | compiler.WARN_LINTING, ) self._execution_options["compiled_cache"][key] = compiled_sql else: @@ -1093,6 +1096,7 @@ class Connection(Connectable): schema_translate_map=self.schema_for_object if not self.schema_for_object.is_default else None, + linting=self.dialect.compiler_linting | compiler.WARN_LINTING, ) ret = self._execute_context( diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 58fe91c7e..5198c8cd6 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -13,6 +13,7 @@ from .. import event from .. import exc from .. import pool as poollib from .. import util +from ..sql import compiler @util.deprecated_params( @@ -142,6 +143,16 @@ def create_engine(url, **kwargs): :param empty_in_strategy: No longer used; SQLAlchemy now uses "empty set" behavior for IN in all cases. + :param enable_from_linting: defaults to True. Will emit a warning + if a given SELECT statement is found to have un-linked FROM elements + which would cause a cartesian product. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`change_4737` + :param encoding: Defaults to ``utf-8``. This is the string encoding used by SQLAlchemy for string encode/decode operations which occur within SQLAlchemy, **outside of @@ -446,6 +457,11 @@ def create_engine(url, **kwargs): dialect_args["dbapi"] = dbapi + dialect_args.setdefault("compiler_linting", compiler.NO_LINTING) + enable_from_linting = kwargs.pop("enable_from_linting", True) + if enable_from_linting: + dialect_args["compiler_linting"] ^= compiler.COLLECT_CARTESIAN_PRODUCTS + for plugin in plugins: plugin.handle_dialect_kwargs(dialect_cls, dialect_args) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 1c995f05f..378890444 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -31,7 +31,6 @@ from ..sql import expression from ..sql import schema from ..sql.elements import quoted_name - AUTOCOMMIT_REGEXP = re.compile( r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE ) @@ -214,6 +213,9 @@ class DefaultDialect(interfaces.Dialect): supports_native_boolean=None, max_identifier_length=None, label_length=None, + # int() is because the @deprecated_params decorator cannot accommodate + # the direct reference to the "NO_LINTING" object + compiler_linting=int(compiler.NO_LINTING), **kwargs ): @@ -249,7 +251,7 @@ class DefaultDialect(interfaces.Dialect): self._user_defined_max_identifier_length ) self.label_length = label_length - + self.compiler_linting = compiler_linting if self.description_encoding == "use_encoding": self._description_decoder = ( processors.to_unicode_processor_factory diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 6554faaa0..488717041 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -5,6 +5,10 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from .compiler import COLLECT_CARTESIAN_PRODUCTS # noqa +from .compiler import FROM_LINTING # noqa +from .compiler import NO_LINTING # noqa +from .compiler import WARN_LINTING # noqa from .expression import Alias # noqa from .expression import alias # noqa from .expression import all_ # noqa diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8499484f3..ed463ebe3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -41,7 +41,6 @@ from .base import NO_ARG from .. import exc from .. import util - RESERVED_WORDS = set( [ "all", @@ -270,6 +269,89 @@ ExpandedState = collections.namedtuple( ) +NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0) + +COLLECT_CARTESIAN_PRODUCTS = util.symbol( + "COLLECT_CARTESIAN_PRODUCTS", + "Collect data on FROMs and cartesian products and gather " + "into 'self.from_linter'", + canonical=1, +) + +WARN_LINTING = util.symbol( + "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2 +) + +FROM_LINTING = util.symbol( + "FROM_LINTING", + "Warn for cartesian products; " + "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING", + canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING, +) + + +class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): + def lint(self, start=None): + froms = self.froms + if not froms: + return None, None + + edges = set(self.edges) + the_rest = set(froms) + + if start is not None: + start_with = start + the_rest.remove(start_with) + else: + start_with = the_rest.pop() + + stack = collections.deque([start_with]) + + while stack and the_rest: + node = stack.popleft() + the_rest.discard(node) + + # comparison of nodes in edges here is based on hash equality, as + # there are "annotated" elements that match the non-annotated ones. + # to remove the need for in-python hash() calls, use native + # containment routines (e.g. "node in edge", "edge.index(node)") + to_remove = {edge for edge in edges if node in edge} + + # appendleft the node in each edge that is not + # the one that matched. + stack.extendleft(edge[not edge.index(node)] for edge in to_remove) + edges.difference_update(to_remove) + + # FROMS left over? boom + if the_rest: + return the_rest, start_with + else: + return None, None + + def warn(self): + the_rest, start_with = self.lint() + + # FROMS left over? boom + if the_rest: + + froms = the_rest + if froms: + template = ( + "SELECT statement has a cartesian product between " + "FROM element(s) {froms} and " + 'FROM element "{start}". Apply join condition(s) ' + "between each element to resolve." + ) + froms_str = ", ".join( + '"{elem}"'.format(elem=self.froms[from_]) + for from_ in froms + ) + message = template.format( + froms=froms_str, start=self.froms[start_with] + ) + util.warn(message) + + class Compiled(object): """Represent a compiled SQL or DDL expression. @@ -568,7 +650,13 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () def __init__( - self, dialect, statement, column_keys=None, inline=False, **kwargs + self, + dialect, + statement, + column_keys=None, + inline=False, + linting=NO_LINTING, + **kwargs ): """Construct a new :class:`.SQLCompiler` object. @@ -592,6 +680,8 @@ class SQLCompiler(Compiled): # execute) self.inline = inline or getattr(statement, "inline", False) + self.linting = linting + # a dictionary of bind parameter keys to BindParameter # instances. self.binds = {} @@ -1547,9 +1637,21 @@ class SQLCompiler(Compiled): return to_update, replacement_expression def visit_binary( - self, binary, override_operator=None, eager_grouping=False, **kw + self, + binary, + override_operator=None, + eager_grouping=False, + from_linter=None, + **kw ): + if from_linter and operators.is_comparison(binary.operator): + from_linter.edges.update( + itertools.product( + binary.left._from_objects, binary.right._from_objects + ) + ) + # don't allow "? = ?" to render if ( self.ansi_bind_rules @@ -1568,7 +1670,9 @@ class SQLCompiler(Compiled): except KeyError: raise exc.UnsupportedCompilationError(self, operator_) else: - return self._generate_generic_binary(binary, opstring, **kw) + return self._generate_generic_binary( + binary, opstring, from_linter=from_linter, **kw + ) def visit_function_as_comparison_op_binary(self, element, operator, **kw): return self.process(element.sql_function, **kw) @@ -1916,6 +2020,7 @@ class SQLCompiler(Compiled): ashint=False, fromhints=None, visiting_cte=None, + from_linter=None, **kwargs ): self._init_cte_state() @@ -2021,6 +2126,9 @@ class SQLCompiler(Compiled): self.ctes[cte] = text if asfrom: + if from_linter: + from_linter.froms[cte] = cte_name + if not is_new_cte and embedded_in_current_named_cte: return self.preparer.format_alias(cte, cte_name) @@ -2043,6 +2151,7 @@ class SQLCompiler(Compiled): subquery=False, lateral=False, enclosing_alias=None, + from_linter=None, **kwargs ): if enclosing_alias is not None and enclosing_alias.element is alias: @@ -2071,6 +2180,9 @@ class SQLCompiler(Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: + if from_linter: + from_linter.froms[alias] = alias_name + inner = alias.element._compiler_dispatch( self, asfrom=True, lateral=lateral, **kwargs ) @@ -2284,6 +2396,7 @@ class SQLCompiler(Compiled): compound_index=0, select_wraps_for=None, lateral=False, + from_linter=None, **kwargs ): @@ -2373,7 +2486,7 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, kwargs + text, select, inner_columns, froms, byfrom, toplevel, kwargs ) if select._statement_hints: @@ -2465,10 +2578,17 @@ class SQLCompiler(Compiled): return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, kwargs + self, text, select, inner_columns, froms, byfrom, toplevel, kwargs ): text += ", ".join(inner_columns) + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + if froms: text += " \nFROM " @@ -2476,7 +2596,11 @@ class SQLCompiler(Compiled): text += ", ".join( [ f._compiler_dispatch( - self, asfrom=True, fromhints=byfrom, **kwargs + self, + asfrom=True, + fromhints=byfrom, + from_linter=from_linter, + **kwargs ) for f in froms ] @@ -2484,7 +2608,12 @@ class SQLCompiler(Compiled): else: text += ", ".join( [ - f._compiler_dispatch(self, asfrom=True, **kwargs) + f._compiler_dispatch( + self, + asfrom=True, + from_linter=from_linter, + **kwargs + ) for f in froms ] ) @@ -2492,10 +2621,18 @@ class SQLCompiler(Compiled): text += self.default_from() if select._whereclause is not None: - t = select._whereclause._compiler_dispatch(self, **kwargs) + t = select._whereclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) if t: text += " \nWHERE " + t + if ( + self.linting & COLLECT_CARTESIAN_PRODUCTS + and self.linting & WARN_LINTING + ): + from_linter.warn() + if select._group_by_clause.clauses: text += self.group_by_clause(select, **kwargs) @@ -2597,8 +2734,12 @@ class SQLCompiler(Compiled): ashint=False, fromhints=None, use_schema=True, + from_linter=None, **kwargs ): + if from_linter: + from_linter.froms[table] = table.fullname + if asfrom or ashint: effective_schema = self.preparer.schema_for_object(table) @@ -2618,7 +2759,10 @@ class SQLCompiler(Compiled): else: return "" - def visit_join(self, join, asfrom=False, **kwargs): + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.add((join.left, join.right)) + if join.full: join_type = " FULL OUTER JOIN " elif join.isouter: @@ -2626,12 +2770,18 @@ class SQLCompiler(Compiled): else: join_type = " JOIN " return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + join_type - + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + " ON " # TODO: likely need asfrom=True here? - + join.onclause._compiler_dispatch(self, **kwargs) + + join.onclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) ) def _setup_crud_hints(self, stmt, table_text): diff --git a/test/base/test_tutorials.py b/test/base/test_tutorials.py index b8322db0a..2c1058b9a 100644 --- a/test/base/test_tutorials.py +++ b/test/base/test_tutorials.py @@ -6,6 +6,7 @@ import os import re import sys +from sqlalchemy import testing from sqlalchemy.testing import config from sqlalchemy.testing import fixtures @@ -86,6 +87,7 @@ class DocTest(fixtures.TestBase): def test_orm(self): self._run_doctest("orm/tutorial.rst") + @testing.emits_warning("SELECT statement has a cartesian") def test_core(self): self._run_doctest("core/tutorial.rst") diff --git a/test/orm/inheritance/test_abc_inheritance.py b/test/orm/inheritance/test_abc_inheritance.py index e2c3a3aa0..b5bff78d2 100644 --- a/test/orm/inheritance/test_abc_inheritance.py +++ b/test/orm/inheritance/test_abc_inheritance.py @@ -210,6 +210,7 @@ def produce_test(parent, child, direction): C, tc, polymorphic_identity="c", + with_polymorphic=("*", tc.join(tb, btoc).join(ta, atob)), inherits=B, inherit_condition=btoc, ) diff --git a/test/orm/inheritance/test_assorted_poly.py b/test/orm/inheritance/test_assorted_poly.py index ecab0a497..616ed79a6 100644 --- a/test/orm/inheritance/test_assorted_poly.py +++ b/test/orm/inheritance/test_assorted_poly.py @@ -597,12 +597,14 @@ class RelationshipTest4(fixtures.MappedTest): mapper( Engineer, engineers, + with_polymorphic=([Engineer], people.join(engineers)), inherits=person_mapper, polymorphic_identity="engineer", ) mapper( Manager, managers, + with_polymorphic=([Manager], people.join(managers)), inherits=person_mapper, polymorphic_identity="manager", ) @@ -1239,12 +1241,14 @@ class GenerativeTest(fixtures.MappedTest, AssertsExecutionResults): mapper( Engineer, engineers, + with_polymorphic=([Engineer], people.join(engineers)), inherits=person_mapper, polymorphic_identity="engineer", ) mapper( Manager, managers, + with_polymorphic=([Manager], people.join(managers)), inherits=person_mapper, polymorphic_identity="manager", ) diff --git a/test/orm/inheritance/test_poly_persistence.py b/test/orm/inheritance/test_poly_persistence.py index 508cb9965..deb33838f 100644 --- a/test/orm/inheritance/test_poly_persistence.py +++ b/test/orm/inheritance/test_poly_persistence.py @@ -502,12 +502,15 @@ class RoundTripTest(PolymorphTest): session = Session() dilbert = get_dilbert(session) + # this unusual test is selecting from the plain people/engineers + # table at the same time as the polymorphic entity is_( dilbert, session.query(Person) .filter( (Engineer.engineer_name == "engineer1") & (engineers.c.person_id == people.c.person_id) + & (people.c.person_id == Person.person_id) ) .first(), ) diff --git a/test/orm/inheritance/test_polymorphic_rel.py b/test/orm/inheritance/test_polymorphic_rel.py index b188e598d..b376be12a 100644 --- a/test/orm/inheritance/test_polymorphic_rel.py +++ b/test/orm/inheritance/test_polymorphic_rel.py @@ -3,9 +3,11 @@ from sqlalchemy import exc as sa_exc from sqlalchemy import func from sqlalchemy import select from sqlalchemy import testing +from sqlalchemy import true from sqlalchemy.orm import aliased from sqlalchemy.orm import create_session from sqlalchemy.orm import defaultload +from sqlalchemy.orm import join from sqlalchemy.orm import joinedload from sqlalchemy.orm import subqueryload from sqlalchemy.orm import with_polymorphic @@ -174,6 +176,7 @@ class _PolymorphicTestBase(object): sess.query(Company, Person, c, e) .join(Person, Company.employees) .join(e, c.employees) + .filter(Person.person_id != e.person_id) .filter(Person.name == "dilbert") .filter(e.name == "wally") ) @@ -897,15 +900,28 @@ class _PolymorphicTestBase(object): ] def go(): + wp = with_polymorphic(Person, "*") eq_( - sess.query(Person) - .with_polymorphic("*") - .options(subqueryload(Engineer.machines)) - .filter(Person.name == "dilbert") + sess.query(wp) + .options(subqueryload(wp.Engineer.machines)) + .filter(wp.name == "dilbert") .all(), expected, ) + # the old version of this test has never worked, apparently, + # was always spitting out a cartesian product. Since we + # are getting rid of query.with_polymorphic() is it not + # worth fixing. + # eq_( + # sess.query(Person) + # .with_polymorphic("*") + # .options(subqueryload(Engineer.machines)) + # .filter(Person.name == "dilbert") + # .all(), + # expected, + # ) + self.assert_sql_count(testing.db, go, 2) def test_query_subclass_join_to_base_relationship(self): @@ -1393,6 +1409,7 @@ class _PolymorphicTestBase(object): .join(Company.employees) .filter(Company.name == "Elbonia, Inc.") .filter(palias.name == "dilbert") + .filter(palias.person_id != Person.person_id) .all(), expected, ) @@ -1420,8 +1437,10 @@ class _PolymorphicTestBase(object): ), ) ] + eq_( sess.query(palias, Company.name, Person) + .select_from(join(palias, Company, true())) .join(Company.employees) .filter(Company.name == "Elbonia, Inc.") .filter(palias.name == "dilbert") @@ -1438,6 +1457,7 @@ class _PolymorphicTestBase(object): .join(Company.employees) .filter(Company.name == "Elbonia, Inc.") .filter(palias.name == "dilbert") + .filter(palias.company_id != Person.company_id) .all(), expected, ) diff --git a/test/orm/inheritance/test_relationship.py b/test/orm/inheritance/test_relationship.py index b5a0fabec..94ab0a994 100644 --- a/test/orm/inheritance/test_relationship.py +++ b/test/orm/inheritance/test_relationship.py @@ -11,6 +11,7 @@ from sqlalchemy.orm import create_session from sqlalchemy.orm import joinedload from sqlalchemy.orm import mapper from sqlalchemy.orm import relationship +from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from sqlalchemy.orm import subqueryload from sqlalchemy.orm import with_polymorphic @@ -1302,10 +1303,12 @@ class SameNamedPropTwoPolymorphicSubClassesTest(fixtures.MappedTest): d = session.query(D).one() def go(): + # NOTE: subqueryload is broken for this case, first found + # when cartesian product detection was added. for a in ( session.query(A) .with_polymorphic([B, C]) - .options(subqueryload(B.related), subqueryload(C.related)) + .options(selectinload(B.related), selectinload(C.related)) ): eq_(a.related, [d]) diff --git a/test/orm/inheritance/test_single.py b/test/orm/inheritance/test_single.py index b070e2848..ed9c781f4 100644 --- a/test/orm/inheritance/test_single.py +++ b/test/orm/inheritance/test_single.py @@ -8,6 +8,7 @@ from sqlalchemy import null from sqlalchemy import select from sqlalchemy import String from sqlalchemy import testing +from sqlalchemy import true from sqlalchemy.orm import aliased from sqlalchemy.orm import Bundle from sqlalchemy.orm import class_mapper @@ -143,29 +144,40 @@ class SingleInheritanceTest(testing.AssertsCompiledSQL, fixtures.MappedTest): session, m1, e1, e2 = self._fixture_one() ealias = aliased(Engineer) - eq_(session.query(Manager, ealias).all(), [(m1, e1), (m1, e2)]) + eq_( + session.query(Manager, ealias).join(ealias, true()).all(), + [(m1, e1), (m1, e2)], + ) eq_(session.query(Manager.name).all(), [("Tom",)]) eq_( - session.query(Manager.name, ealias.name).all(), + session.query(Manager.name, ealias.name) + .join(ealias, true()) + .all(), [("Tom", "Kurt"), ("Tom", "Ed")], ) eq_( - session.query( - func.upper(Manager.name), func.upper(ealias.name) - ).all(), + session.query(func.upper(Manager.name), func.upper(ealias.name)) + .join(ealias, true()) + .all(), [("TOM", "KURT"), ("TOM", "ED")], ) eq_( - session.query(Manager).add_entity(ealias).all(), + session.query(Manager) + .add_entity(ealias) + .join(ealias, true()) + .all(), [(m1, e1), (m1, e2)], ) eq_( - session.query(Manager.name).add_column(ealias.name).all(), + session.query(Manager.name) + .add_column(ealias.name) + .join(ealias, true()) + .all(), [("Tom", "Kurt"), ("Tom", "Ed")], ) diff --git a/test/orm/test_deprecations.py b/test/orm/test_deprecations.py index 17954f308..d5a46e9ea 100644 --- a/test/orm/test_deprecations.py +++ b/test/orm/test_deprecations.py @@ -423,10 +423,8 @@ class DeprecatedQueryTest(_fixtures.FixtureTest, AssertsCompiledSQL): with self._expect_implicit_subquery(): eq_( s.query(User) - .select_from( - text("select * from users").columns( - id=Integer, name=String - ) + .select_entity_from( + text("select * from users").columns(User.id, User.name) ) .order_by(User.id) .all(), diff --git a/test/orm/test_eager_relations.py b/test/orm/test_eager_relations.py index 659f6e103..bf39b25a6 100644 --- a/test/orm/test_eager_relations.py +++ b/test/orm/test_eager_relations.py @@ -5583,6 +5583,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(aa, A) .filter(aa.id == 1) .filter(A.id == 2) + .filter(aa.id != A.id) .options(joinedload("bs").joinedload("cs")) ) self._run_tests(q, 1) @@ -5595,6 +5596,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(A, aa) .filter(aa.id == 2) .filter(A.id == 1) + .filter(aa.id != A.id) .options(joinedload("bs").joinedload("cs")) ) self._run_tests(q, 1) @@ -5607,6 +5609,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(aa, A) .filter(aa.id == 1) .filter(A.id == 2) + .filter(aa.id != A.id) .options(joinedload(A.bs).joinedload(B.cs)) ) self._run_tests(q, 3) @@ -5619,6 +5622,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(aa, A) .filter(aa.id == 1) .filter(A.id == 2) + .filter(aa.id != A.id) .options(defaultload(A.bs).joinedload(B.cs)) ) self._run_tests(q, 3) @@ -5629,7 +5633,13 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): aa = aliased(A) opt = Load(A).joinedload(A.bs).joinedload(B.cs) - q = s.query(aa, A).filter(aa.id == 1).filter(A.id == 2).options(opt) + q = ( + s.query(aa, A) + .filter(aa.id == 1) + .filter(A.id == 2) + .filter(aa.id != A.id) + .options(opt) + ) self._run_tests(q, 3) def test_pathed_lazyload_plus_joined_aliased_abs_bcs(self): @@ -5638,7 +5648,13 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): aa = aliased(A) opt = Load(aa).defaultload(aa.bs).joinedload(B.cs) - q = s.query(aa, A).filter(aa.id == 1).filter(A.id == 2).options(opt) + q = ( + s.query(aa, A) + .filter(aa.id == 1) + .filter(A.id == 2) + .filter(aa.id != A.id) + .options(opt) + ) self._run_tests(q, 2) def test_pathed_joinedload_aliased_abs_bcs(self): @@ -5647,7 +5663,13 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): aa = aliased(A) opt = Load(aa).joinedload(aa.bs).joinedload(B.cs) - q = s.query(aa, A).filter(aa.id == 1).filter(A.id == 2).options(opt) + q = ( + s.query(aa, A) + .filter(aa.id == 1) + .filter(A.id == 2) + .filter(aa.id != A.id) + .options(opt) + ) self._run_tests(q, 1) def test_lazyload_plus_joined_aliased_abs_bcs(self): @@ -5658,6 +5680,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(aa, A) .filter(aa.id == 1) .filter(A.id == 2) + .filter(aa.id != A.id) .options(defaultload(aa.bs).joinedload(B.cs)) ) self._run_tests(q, 2) @@ -5670,6 +5693,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(aa, A) .filter(aa.id == 1) .filter(A.id == 2) + .filter(aa.id != A.id) .options(joinedload(aa.bs).joinedload(B.cs)) ) self._run_tests(q, 1) @@ -5682,6 +5706,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(A, aa) .filter(aa.id == 2) .filter(A.id == 1) + .filter(aa.id != A.id) .options(joinedload(aa.bs).joinedload(B.cs)) ) self._run_tests(q, 3) @@ -5694,6 +5719,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(A, aa) .filter(aa.id == 2) .filter(A.id == 1) + .filter(aa.id != A.id) .options(defaultload(aa.bs).joinedload(B.cs)) ) self._run_tests(q, 3) @@ -5706,6 +5732,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(A, aa) .filter(aa.id == 2) .filter(A.id == 1) + .filter(aa.id != A.id) .options(defaultload(A.bs).joinedload(B.cs)) ) self._run_tests(q, 2) @@ -5718,6 +5745,7 @@ class LazyLoadOptSpecificityTest(fixtures.DeclarativeMappedTest): s.query(A, aa) .filter(aa.id == 2) .filter(A.id == 1) + .filter(aa.id != A.id) .options(joinedload(A.bs).joinedload(B.cs)) ) self._run_tests(q, 1) diff --git a/test/orm/test_froms.py b/test/orm/test_froms.py index 54714864d..7195f53cb 100644 --- a/test/orm/test_froms.py +++ b/test/orm/test_froms.py @@ -9,11 +9,13 @@ from sqlalchemy import ForeignKey from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import literal_column +from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import String from sqlalchemy import Table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import util from sqlalchemy.engine import default from sqlalchemy.orm import aliased @@ -1481,6 +1483,7 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): q2 = ( q.select_entity_from(sel) .filter(u2.id > 1) + .filter(or_(u2.id == User.id, u2.id != User.id)) .order_by(User.id, sel.c.id, u2.id) .values(User.name, sel.c.name, u2.name) ) @@ -1853,17 +1856,17 @@ class MixedEntitiesTest(QueryTest, AssertsCompiledSQL): .filter(Order.id > oalias.id) .order_by(Order.id, oalias.id), sess.query(Order, oalias) + .filter(Order.id > oalias.id) .from_self() .filter(Order.user_id == oalias.user_id) .filter(Order.user_id == 7) - .filter(Order.id > oalias.id) .order_by(Order.id, oalias.id), # same thing, but reversed. sess.query(oalias, Order) + .filter(Order.id < oalias.id) .from_self() .filter(oalias.user_id == Order.user_id) .filter(oalias.user_id == 7) - .filter(Order.id < oalias.id) .order_by(oalias.id, Order.id), # here we go....two layers of aliasing sess.query(Order, oalias) @@ -3537,7 +3540,11 @@ class LabelCollideTest(fixtures.MappedTest): def test_overlap_plain(self): s = Session() - row = s.query(self.classes.Foo, self.classes.Bar).all()[0] + row = ( + s.query(self.classes.Foo, self.classes.Bar) + .join(self.classes.Bar, true()) + .all()[0] + ) def go(): eq_(row.Foo.id, 1) @@ -3550,7 +3557,12 @@ class LabelCollideTest(fixtures.MappedTest): def test_overlap_subquery(self): s = Session() - row = s.query(self.classes.Foo, self.classes.Bar).from_self().all()[0] + row = ( + s.query(self.classes.Foo, self.classes.Bar) + .join(self.classes.Bar, true()) + .from_self() + .all()[0] + ) def go(): eq_(row.Foo.id, 1) diff --git a/test/orm/test_pickled.py b/test/orm/test_pickled.py index 72a68c42b..4da583a0c 100644 --- a/test/orm/test_pickled.py +++ b/test/orm/test_pickled.py @@ -777,7 +777,11 @@ class TupleLabelTest(_fixtures.FixtureTest): eq_(row.foobar, row[1]) oalias = aliased(Order) - for row in sess.query(User, oalias).join(User.orders).all(): + for row in ( + sess.query(User, oalias) + .join(User.orders.of_type(oalias)) + .all() + ): if pickled is not False: row = pickle.loads(pickle.dumps(row, pickled)) eq_(list(row.keys()), ["User"]) diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 271d85dd6..55809ad38 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -28,6 +28,7 @@ from sqlalchemy import String from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import Unicode from sqlalchemy import union from sqlalchemy import util @@ -3783,7 +3784,7 @@ class CountTest(QueryTest): User, Address = self.classes.User, self.classes.Address s = create_session() - q = s.query(User, Address) + q = s.query(User, Address).join(Address, true()) eq_(q.count(), 20) # cartesian product q = s.query(User, Address).join(User.addresses) @@ -3793,10 +3794,10 @@ class CountTest(QueryTest): User, Address = self.classes.User, self.classes.Address s = create_session() - q = s.query(User, Address).limit(2) + q = s.query(User, Address).join(Address, true()).limit(2) eq_(q.count(), 2) - q = s.query(User, Address).limit(100) + q = s.query(User, Address).join(Address, true()).limit(100) eq_(q.count(), 20) q = s.query(User, Address).join(User.addresses).limit(100) @@ -3818,7 +3819,7 @@ class CountTest(QueryTest): q = s.query(User.name) eq_(q.count(), 4) - q = s.query(User.name, Address) + q = s.query(User.name, Address).join(Address, true()) eq_(q.count(), 20) q = s.query(Address.user_id) @@ -3888,7 +3889,9 @@ class DistinctTest(QueryTest, AssertsCompiledSQL): q = ( sess.query(User.id, User.name.label("foo"), Address.id) + .join(Address, true()) .filter(User.name == "jack") + .filter(User.id + Address.user_id > 0) .distinct() .order_by(User.id, User.name, Address.email_address) ) @@ -4541,9 +4544,9 @@ class TextTest(QueryTest, AssertsCompiledSQL): eq_( s.query(User) - .select_from( + .select_entity_from( text("select * from users") - .columns(id=Integer, name=String) + .columns(User.id, User.name) .subquery() ) .order_by(User.id) diff --git a/test/profiles.txt b/test/profiles.txt index 5e2ca814a..41f1fab7a 100644 --- a/test/profiles.txt +++ b/test/profiles.txt @@ -1,15 +1,15 @@ # /home/classic/dev/sqlalchemy/test/profiles.txt # This file is written out on a per-environment basis. -# For each test in aaa_profiling, the corresponding function and +# For each test in aaa_profiling, the corresponding function and # environment is located within this file. If it doesn't exist, # the test is skipped. -# If a callcount does exist, it is compared to what we received. +# If a callcount does exist, it is compared to what we received. # assertions are raised if the counts do not match. -# -# To add a new callcount test, apply the function_call_count -# decorator and re-run the tests using the --write-profiles +# +# To add a new callcount test, apply the function_call_count +# decorator and re-run the tests using the --write-profiles # option - this file will be rewritten including the new count. -# +# # TEST: test.aaa_profiling.test_compiler.CompileTest.test_insert @@ -1027,14 +1027,10 @@ test.aaa_profiling.test_resultset.ResultSetTest.test_unicode 3.7_sqlite_pysqlite # TEST: test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation -test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 6412,322,4242,12454,1244,2187,2770 -test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 6429,322,4242,13149,1341,2187,2766 -test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6177,306,4162,12597,1233,2133,2650 -test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6260,306,4242,13203,1344,2151,2840 +test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 6177,306,4162,12597,1233,2133,2852 +test.aaa_profiling.test_zoomark.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 6260,306,4242,13203,1344,2151,3046 # TEST: test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation -test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_cextensions 7085,420,7205,18682,1257,2846 -test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 2.7_postgresql_psycopg2_dbapiunicode_nocextensions 7130,423,7229,20001,1346,2864 test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_cextensions 7090,411,7281,19190,1247,2897 test.aaa_profiling.test_zoomark_orm.ZooMarkTest.test_invocation 3.7_postgresql_psycopg2_dbapiunicode_nocextensions 7186,416,7465,20675,1350,2957 diff --git a/test/sql/test_defaults.py b/test/sql/test_defaults.py index f033abab2..b31b070d8 100644 --- a/test/sql/test_defaults.py +++ b/test/sql/test_defaults.py @@ -852,7 +852,7 @@ class CTEDefaultTest(fixtures.TablesTest): if b == "select": conn.execute(p.insert().values(s=1)) - stmt = select([p.c.s, cte.c.z]) + stmt = select([p.c.s, cte.c.z]).where(p.c.s == cte.c.z) elif b == "insert": sel = select([1, cte.c.z]) stmt = ( diff --git a/test/sql/test_from_linter.py b/test/sql/test_from_linter.py new file mode 100644 index 000000000..bf2f06b57 --- /dev/null +++ b/test/sql/test_from_linter.py @@ -0,0 +1,277 @@ +from sqlalchemy import Integer +from sqlalchemy import select +from sqlalchemy import sql +from sqlalchemy import true +from sqlalchemy.testing import config +from sqlalchemy.testing import engines +from sqlalchemy.testing import expect_warnings +from sqlalchemy.testing import fixtures +from sqlalchemy.testing import is_ +from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import Table + + +def find_unmatching_froms(query, start=None): + compiled = query.compile(linting=sql.COLLECT_CARTESIAN_PRODUCTS) + + return compiled.from_linter.lint(start) + + +class TestFindUnmatchingFroms(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) + Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + Table("table_c", metadata, Column("col_c", Integer, primary_key=True)) + Table("table_d", metadata, Column("col_d", Integer, primary_key=True)) + + def setup(self): + self.a = self.tables.table_a + self.b = self.tables.table_b + self.c = self.tables.table_c + self.d = self.tables.table_d + + def test_everything_is_connected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c) + .select_from(self.d) + .where(self.d.c.col_d == self.b.c.col_b) + .where(self.c.c.col_c == self.d.c.col_d) + .where(self.c.c.col_c == 5) + ) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert not froms + + def test_plain_cartesian(self): + query = select([self.a]).where(self.b.c.col_b == 5) + froms, start = find_unmatching_froms(query, self.a) + assert start == self.a + assert froms == {self.b} + + froms, start = find_unmatching_froms(query, self.b) + assert start == self.b + assert froms == {self.a} + + def test_count_non_eq_comparison_operators(self): + query = select([self.a]).where(self.a.c.col_a > self.b.c.col_b) + froms, start = find_unmatching_froms(query, self.a) + is_(start, None) + is_(froms, None) + + def test_dont_count_non_comparison_operators(self): + query = select([self.a]).where(self.a.c.col_a + self.b.c.col_b == 5) + froms, start = find_unmatching_froms(query, self.a) + assert start == self.a + assert froms == {self.b} + + def test_disconnect_between_ab_cd(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c) + .select_from(self.d) + .where(self.c.c.col_c == self.d.c.col_d) + .where(self.c.c.col_c == 5) + ) + for start in self.a, self.b: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.c, self.d} + for start in self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.a, self.b} + + def test_c_and_d_both_disconnected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .where(self.c.c.col_c == 5) + .where(self.d.c.col_d == 10) + ) + for start in self.a, self.b: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.c, self.d} + + froms, start = find_unmatching_froms(query, self.c) + assert start == self.c + assert froms == {self.a, self.b, self.d} + + froms, start = find_unmatching_froms(query, self.d) + assert start == self.d + assert froms == {self.a, self.b, self.c} + + def test_now_connected(self): + query = ( + select([self.a]) + .select_from(self.a.join(self.b, self.a.c.col_a == self.b.c.col_b)) + .select_from(self.c.join(self.d, self.c.c.col_c == self.d.c.col_d)) + .where(self.c.c.col_c == self.b.c.col_b) + .where(self.c.c.col_c == 5) + .where(self.d.c.col_d == 10) + ) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c, self.d: + froms, start = find_unmatching_froms(query, start) + assert not froms + + def test_disconnected_subquery(self): + subq = ( + select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() + ) + stmt = select([self.c]).select_from(subq) + + froms, start = find_unmatching_froms(stmt, self.c) + assert start == self.c + assert froms == {subq} + + froms, start = find_unmatching_froms(stmt, subq) + assert start == subq + assert froms == {self.c} + + def test_now_connect_it(self): + subq = ( + select([self.a]).where(self.a.c.col_a == self.b.c.col_b).subquery() + ) + stmt = ( + select([self.c]) + .select_from(subq) + .where(self.c.c.col_c == subq.c.col_a) + ) + + froms, start = find_unmatching_froms(stmt) + assert not froms + + for start in self.c, subq: + froms, start = find_unmatching_froms(stmt, start) + assert not froms + + def test_right_nested_join_without_issue(self): + query = select([self.a]).select_from( + self.a.join( + self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), + self.a.c.col_a == self.b.c.col_b, + ) + ) + froms, start = find_unmatching_froms(query) + assert not froms + + for start in self.a, self.b, self.c: + froms, start = find_unmatching_froms(query, start) + assert not froms + + def test_join_on_true(self): + # test that a join(a, b) counts a->b as an edge even if there isn't + # actually a join condition. this essentially allows a cartesian + # product to be added explicitly. + + query = select([self.a]).select_from(self.a.join(self.b, true())) + froms, start = find_unmatching_froms(query) + assert not froms + + def test_right_nested_join_with_an_issue(self): + query = ( + select([self.a]) + .select_from( + self.a.join( + self.b.join(self.c, self.b.c.col_b == self.c.c.col_c), + self.a.c.col_a == self.b.c.col_b, + ) + ) + .where(self.d.c.col_d == 5) + ) + + for start in self.a, self.b, self.c: + froms, start = find_unmatching_froms(query, start) + assert start == start + assert froms == {self.d} + + froms, start = find_unmatching_froms(query, self.d) + assert start == self.d + assert froms == {self.a, self.b, self.c} + + def test_no_froms(self): + query = select([1]) + + froms, start = find_unmatching_froms(query) + assert not froms + + +class TestLinter(fixtures.TablesTest): + @classmethod + def define_tables(cls, metadata): + Table("table_a", metadata, Column("col_a", Integer, primary_key=True)) + Table("table_b", metadata, Column("col_b", Integer, primary_key=True)) + + @classmethod + def setup_bind(cls): + # from linting is enabled by default + return config.db + + def test_noop_for_unhandled_objects(self): + with self.bind.connect() as conn: + conn.execute("SELECT 1;").fetchone() + + def test_does_not_modify_query(self): + with self.bind.connect() as conn: + [result] = conn.execute(select([1])).fetchone() + assert result == 1 + + def test_warn_simple(self): + a, b = self.tables("table_a", "table_b") + query = select([a.c.col_a]).where(b.c.col_b == 5) + + with expect_warnings( + r"SELECT statement has a cartesian product between FROM " + r'element\(s\) "table_[ab]" ' + r'and FROM element "table_[ba]"' + ): + with self.bind.connect() as conn: + conn.execute(query) + + def test_warn_anon_alias(self): + a, b = self.tables("table_a", "table_b") + + b_alias = b.alias() + query = select([a.c.col_a]).where(b_alias.c.col_b == 5) + + with expect_warnings( + r"SELECT statement has a cartesian product between FROM " + r'element\(s\) "table_(?:a|b_1)" ' + r'and FROM element "table_(?:a|b_1)"' + ): + with self.bind.connect() as conn: + conn.execute(query) + + def test_warn_anon_cte(self): + a, b = self.tables("table_a", "table_b") + + b_cte = select([b]).cte() + query = select([a.c.col_a]).where(b_cte.c.col_b == 5) + + with expect_warnings( + r"SELECT statement has a cartesian product between " + r"FROM element\(s\) " + r'"(?:anon_1|table_a)" ' + r'and FROM element "(?:anon_1|table_a)"' + ): + with self.bind.connect() as conn: + conn.execute(query) + + def test_no_linting(self): + eng = engines.testing_engine(options={"enable_from_linting": False}) + eng.pool = self.bind.pool # needed for SQLite + a, b = self.tables("table_a", "table_b") + query = select([a.c.col_a]).where(b.c.col_b == 5) + + with eng.connect() as conn: + conn.execute(query) diff --git a/test/sql/test_resultset.py b/test/sql/test_resultset.py index 794508a32..8aa524d78 100644 --- a/test/sql/test_resultset.py +++ b/test/sql/test_resultset.py @@ -19,6 +19,7 @@ from sqlalchemy import String from sqlalchemy import table from sqlalchemy import testing from sqlalchemy import text +from sqlalchemy import true from sqlalchemy import type_coerce from sqlalchemy import TypeDecorator from sqlalchemy import util @@ -771,7 +772,11 @@ class ResultProxyTest(fixtures.TablesTest): users.insert().execute(user_id=1, user_name="john") ua = users.alias() u2 = users.alias() - result = select([users.c.user_id, ua.c.user_id]).execute() + result = ( + select([users.c.user_id, ua.c.user_id]) + .select_from(users.join(ua, true())) + .execute() + ) row = result.first() # as of 1.1 issue #3501, we use pure positional @@ -1414,7 +1419,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed2 = self.tables.keyed2 - row = testing.db.execute(select([keyed1, keyed2])).first() + row = testing.db.execute( + select([keyed1, keyed2]).select_from(keyed1.join(keyed2, true())) + ).first() # column access is unambiguous eq_(row[self.tables.keyed2.c.b], "b2") @@ -1446,7 +1453,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed2 = self.tables.keyed2 row = testing.db.execute( - select([keyed1, keyed2]).apply_labels() + select([keyed1, keyed2]) + .select_from(keyed1.join(keyed2, true())) + .apply_labels() ).first() # column access is unambiguous @@ -1459,7 +1468,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed4 = self.tables.keyed4 - row = testing.db.execute(select([keyed1, keyed4])).first() + row = testing.db.execute( + select([keyed1, keyed4]).select_from(keyed1.join(keyed4, true())) + ).first() eq_(row.b, "b4") eq_(row.q, "q4") eq_(row.a, "a1") @@ -1470,7 +1481,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed1 = self.tables.keyed1 keyed3 = self.tables.keyed3 - row = testing.db.execute(select([keyed1, keyed3])).first() + row = testing.db.execute( + select([keyed1, keyed3]).select_from(keyed1.join(keyed3, true())) + ).first() eq_(row.q, "c1") # prior to 1.4 #4887, this raised an "ambiguous column name 'a'"" @@ -1493,7 +1506,9 @@ class KeyTargetingTest(fixtures.TablesTest): keyed2 = self.tables.keyed2 row = testing.db.execute( - select([keyed1, keyed2]).apply_labels() + select([keyed1, keyed2]) + .select_from(keyed1.join(keyed2, true())) + .apply_labels() ).first() eq_(row.keyed1_b, "a1") eq_(row.keyed1_a, "a1") @@ -1515,18 +1530,22 @@ class KeyTargetingTest(fixtures.TablesTest): keyed2 = self.tables.keyed2 keyed3 = self.tables.keyed3 - stmt = select( - [ - keyed2.c.a, - keyed3.c.a, - keyed2.c.a, - keyed2.c.a, - keyed3.c.a, - keyed3.c.a, - keyed3.c.d, - keyed3.c.d, - ] - ).apply_labels() + stmt = ( + select( + [ + keyed2.c.a, + keyed3.c.a, + keyed2.c.a, + keyed2.c.a, + keyed3.c.a, + keyed3.c.a, + keyed3.c.d, + keyed3.c.d, + ] + ) + .select_from(keyed2.join(keyed3, true())) + .apply_labels() + ) result = testing.db.execute(stmt) is_false(result._metadata.matched_on_name) |
