diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/base.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/base.py | 13 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/create.py | 16 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 176 |
7 files changed, 211 insertions, 23 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 8241d951b..6e84c9da1 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1434,7 +1434,10 @@ class MySQLCompiler(compiler.SQLCompiler): else: return "" - def visit_join(self, join, asfrom=False, **kwargs): + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.add((join.left, join.right)) + if join.full: join_type = " FULL OUTER JOIN " elif join.isouter: @@ -1444,11 +1447,15 @@ class MySQLCompiler(compiler.SQLCompiler): return "".join( ( - self.process(join.left, asfrom=True, **kwargs), + self.process( + join.left, asfrom=True, from_linter=from_linter, **kwargs + ), join_type, - self.process(join.right, asfrom=True, **kwargs), + self.process( + join.right, asfrom=True, from_linter=from_linter, **kwargs + ), " ON ", - self.process(join.onclause, **kwargs), + self.process(join.onclause, from_linter=from_linter, **kwargs), ) ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index 9cb25b934..87e0baa58 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -829,19 +829,24 @@ class OracleCompiler(compiler.SQLCompiler): return " FROM DUAL" - def visit_join(self, join, **kwargs): + def visit_join(self, join, from_linter=None, **kwargs): if self.dialect.use_ansi: - return compiler.SQLCompiler.visit_join(self, join, **kwargs) + return compiler.SQLCompiler.visit_join( + self, join, from_linter=from_linter, **kwargs + ) else: + if from_linter: + from_linter.edges.add((join.left, join.right)) + kwargs["asfrom"] = True if isinstance(join.right, expression.FromGrouping): right = join.right.element else: right = join.right return ( - self.process(join.left, **kwargs) + self.process(join.left, from_linter=from_linter, **kwargs) + ", " - + self.process(right, **kwargs) + + self.process(right, from_linter=from_linter, **kwargs) ) def _get_nonansi_join_whereclause(self, froms): diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 88558df5d..462e5f9ec 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -16,6 +16,7 @@ from .. import exc from .. import inspection from .. import log from .. import util +from ..sql import compiler from ..sql import schema from ..sql import util as sql_util @@ -1083,6 +1084,8 @@ class Connection(Connectable): schema_translate_map=self.schema_for_object if not self.schema_for_object.is_default else None, + linting=self.dialect.compiler_linting + | compiler.WARN_LINTING, ) self._execution_options["compiled_cache"][key] = compiled_sql else: @@ -1093,6 +1096,7 @@ class Connection(Connectable): schema_translate_map=self.schema_for_object if not self.schema_for_object.is_default else None, + linting=self.dialect.compiler_linting | compiler.WARN_LINTING, ) ret = self._execute_context( diff --git a/lib/sqlalchemy/engine/create.py b/lib/sqlalchemy/engine/create.py index 58fe91c7e..5198c8cd6 100644 --- a/lib/sqlalchemy/engine/create.py +++ b/lib/sqlalchemy/engine/create.py @@ -13,6 +13,7 @@ from .. import event from .. import exc from .. import pool as poollib from .. import util +from ..sql import compiler @util.deprecated_params( @@ -142,6 +143,16 @@ def create_engine(url, **kwargs): :param empty_in_strategy: No longer used; SQLAlchemy now uses "empty set" behavior for IN in all cases. + :param enable_from_linting: defaults to True. Will emit a warning + if a given SELECT statement is found to have un-linked FROM elements + which would cause a cartesian product. + + .. versionadded:: 1.4 + + .. seealso:: + + :ref:`change_4737` + :param encoding: Defaults to ``utf-8``. This is the string encoding used by SQLAlchemy for string encode/decode operations which occur within SQLAlchemy, **outside of @@ -446,6 +457,11 @@ def create_engine(url, **kwargs): dialect_args["dbapi"] = dbapi + dialect_args.setdefault("compiler_linting", compiler.NO_LINTING) + enable_from_linting = kwargs.pop("enable_from_linting", True) + if enable_from_linting: + dialect_args["compiler_linting"] ^= compiler.COLLECT_CARTESIAN_PRODUCTS + for plugin in plugins: plugin.handle_dialect_kwargs(dialect_cls, dialect_args) diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index 1c995f05f..378890444 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -31,7 +31,6 @@ from ..sql import expression from ..sql import schema from ..sql.elements import quoted_name - AUTOCOMMIT_REGEXP = re.compile( r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER)", re.I | re.UNICODE ) @@ -214,6 +213,9 @@ class DefaultDialect(interfaces.Dialect): supports_native_boolean=None, max_identifier_length=None, label_length=None, + # int() is because the @deprecated_params decorator cannot accommodate + # the direct reference to the "NO_LINTING" object + compiler_linting=int(compiler.NO_LINTING), **kwargs ): @@ -249,7 +251,7 @@ class DefaultDialect(interfaces.Dialect): self._user_defined_max_identifier_length ) self.label_length = label_length - + self.compiler_linting = compiler_linting if self.description_encoding == "use_encoding": self._description_decoder = ( processors.to_unicode_processor_factory diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 6554faaa0..488717041 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -5,6 +5,10 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php +from .compiler import COLLECT_CARTESIAN_PRODUCTS # noqa +from .compiler import FROM_LINTING # noqa +from .compiler import NO_LINTING # noqa +from .compiler import WARN_LINTING # noqa from .expression import Alias # noqa from .expression import alias # noqa from .expression import all_ # noqa diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 8499484f3..ed463ebe3 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -41,7 +41,6 @@ from .base import NO_ARG from .. import exc from .. import util - RESERVED_WORDS = set( [ "all", @@ -270,6 +269,89 @@ ExpandedState = collections.namedtuple( ) +NO_LINTING = util.symbol("NO_LINTING", "Disable all linting.", canonical=0) + +COLLECT_CARTESIAN_PRODUCTS = util.symbol( + "COLLECT_CARTESIAN_PRODUCTS", + "Collect data on FROMs and cartesian products and gather " + "into 'self.from_linter'", + canonical=1, +) + +WARN_LINTING = util.symbol( + "WARN_LINTING", "Emit warnings for linters that find problems", canonical=2 +) + +FROM_LINTING = util.symbol( + "FROM_LINTING", + "Warn for cartesian products; " + "combines COLLECT_CARTESIAN_PRODUCTS and WARN_LINTING", + canonical=COLLECT_CARTESIAN_PRODUCTS | WARN_LINTING, +) + + +class FromLinter(collections.namedtuple("FromLinter", ["froms", "edges"])): + def lint(self, start=None): + froms = self.froms + if not froms: + return None, None + + edges = set(self.edges) + the_rest = set(froms) + + if start is not None: + start_with = start + the_rest.remove(start_with) + else: + start_with = the_rest.pop() + + stack = collections.deque([start_with]) + + while stack and the_rest: + node = stack.popleft() + the_rest.discard(node) + + # comparison of nodes in edges here is based on hash equality, as + # there are "annotated" elements that match the non-annotated ones. + # to remove the need for in-python hash() calls, use native + # containment routines (e.g. "node in edge", "edge.index(node)") + to_remove = {edge for edge in edges if node in edge} + + # appendleft the node in each edge that is not + # the one that matched. + stack.extendleft(edge[not edge.index(node)] for edge in to_remove) + edges.difference_update(to_remove) + + # FROMS left over? boom + if the_rest: + return the_rest, start_with + else: + return None, None + + def warn(self): + the_rest, start_with = self.lint() + + # FROMS left over? boom + if the_rest: + + froms = the_rest + if froms: + template = ( + "SELECT statement has a cartesian product between " + "FROM element(s) {froms} and " + 'FROM element "{start}". Apply join condition(s) ' + "between each element to resolve." + ) + froms_str = ", ".join( + '"{elem}"'.format(elem=self.froms[from_]) + for from_ in froms + ) + message = template.format( + froms=froms_str, start=self.froms[start_with] + ) + util.warn(message) + + class Compiled(object): """Represent a compiled SQL or DDL expression. @@ -568,7 +650,13 @@ class SQLCompiler(Compiled): insert_prefetch = update_prefetch = () def __init__( - self, dialect, statement, column_keys=None, inline=False, **kwargs + self, + dialect, + statement, + column_keys=None, + inline=False, + linting=NO_LINTING, + **kwargs ): """Construct a new :class:`.SQLCompiler` object. @@ -592,6 +680,8 @@ class SQLCompiler(Compiled): # execute) self.inline = inline or getattr(statement, "inline", False) + self.linting = linting + # a dictionary of bind parameter keys to BindParameter # instances. self.binds = {} @@ -1547,9 +1637,21 @@ class SQLCompiler(Compiled): return to_update, replacement_expression def visit_binary( - self, binary, override_operator=None, eager_grouping=False, **kw + self, + binary, + override_operator=None, + eager_grouping=False, + from_linter=None, + **kw ): + if from_linter and operators.is_comparison(binary.operator): + from_linter.edges.update( + itertools.product( + binary.left._from_objects, binary.right._from_objects + ) + ) + # don't allow "? = ?" to render if ( self.ansi_bind_rules @@ -1568,7 +1670,9 @@ class SQLCompiler(Compiled): except KeyError: raise exc.UnsupportedCompilationError(self, operator_) else: - return self._generate_generic_binary(binary, opstring, **kw) + return self._generate_generic_binary( + binary, opstring, from_linter=from_linter, **kw + ) def visit_function_as_comparison_op_binary(self, element, operator, **kw): return self.process(element.sql_function, **kw) @@ -1916,6 +2020,7 @@ class SQLCompiler(Compiled): ashint=False, fromhints=None, visiting_cte=None, + from_linter=None, **kwargs ): self._init_cte_state() @@ -2021,6 +2126,9 @@ class SQLCompiler(Compiled): self.ctes[cte] = text if asfrom: + if from_linter: + from_linter.froms[cte] = cte_name + if not is_new_cte and embedded_in_current_named_cte: return self.preparer.format_alias(cte, cte_name) @@ -2043,6 +2151,7 @@ class SQLCompiler(Compiled): subquery=False, lateral=False, enclosing_alias=None, + from_linter=None, **kwargs ): if enclosing_alias is not None and enclosing_alias.element is alias: @@ -2071,6 +2180,9 @@ class SQLCompiler(Compiled): if ashint: return self.preparer.format_alias(alias, alias_name) elif asfrom: + if from_linter: + from_linter.froms[alias] = alias_name + inner = alias.element._compiler_dispatch( self, asfrom=True, lateral=lateral, **kwargs ) @@ -2284,6 +2396,7 @@ class SQLCompiler(Compiled): compound_index=0, select_wraps_for=None, lateral=False, + from_linter=None, **kwargs ): @@ -2373,7 +2486,7 @@ class SQLCompiler(Compiled): ] text = self._compose_select_body( - text, select, inner_columns, froms, byfrom, kwargs + text, select, inner_columns, froms, byfrom, toplevel, kwargs ) if select._statement_hints: @@ -2465,10 +2578,17 @@ class SQLCompiler(Compiled): return froms def _compose_select_body( - self, text, select, inner_columns, froms, byfrom, kwargs + self, text, select, inner_columns, froms, byfrom, toplevel, kwargs ): text += ", ".join(inner_columns) + if self.linting & COLLECT_CARTESIAN_PRODUCTS: + from_linter = FromLinter({}, set()) + if toplevel: + self.from_linter = from_linter + else: + from_linter = None + if froms: text += " \nFROM " @@ -2476,7 +2596,11 @@ class SQLCompiler(Compiled): text += ", ".join( [ f._compiler_dispatch( - self, asfrom=True, fromhints=byfrom, **kwargs + self, + asfrom=True, + fromhints=byfrom, + from_linter=from_linter, + **kwargs ) for f in froms ] @@ -2484,7 +2608,12 @@ class SQLCompiler(Compiled): else: text += ", ".join( [ - f._compiler_dispatch(self, asfrom=True, **kwargs) + f._compiler_dispatch( + self, + asfrom=True, + from_linter=from_linter, + **kwargs + ) for f in froms ] ) @@ -2492,10 +2621,18 @@ class SQLCompiler(Compiled): text += self.default_from() if select._whereclause is not None: - t = select._whereclause._compiler_dispatch(self, **kwargs) + t = select._whereclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) if t: text += " \nWHERE " + t + if ( + self.linting & COLLECT_CARTESIAN_PRODUCTS + and self.linting & WARN_LINTING + ): + from_linter.warn() + if select._group_by_clause.clauses: text += self.group_by_clause(select, **kwargs) @@ -2597,8 +2734,12 @@ class SQLCompiler(Compiled): ashint=False, fromhints=None, use_schema=True, + from_linter=None, **kwargs ): + if from_linter: + from_linter.froms[table] = table.fullname + if asfrom or ashint: effective_schema = self.preparer.schema_for_object(table) @@ -2618,7 +2759,10 @@ class SQLCompiler(Compiled): else: return "" - def visit_join(self, join, asfrom=False, **kwargs): + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + if from_linter: + from_linter.edges.add((join.left, join.right)) + if join.full: join_type = " FULL OUTER JOIN " elif join.isouter: @@ -2626,12 +2770,18 @@ class SQLCompiler(Compiled): else: join_type = " JOIN " return ( - join.left._compiler_dispatch(self, asfrom=True, **kwargs) + join.left._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + join_type - + join.right._compiler_dispatch(self, asfrom=True, **kwargs) + + join.right._compiler_dispatch( + self, asfrom=True, from_linter=from_linter, **kwargs + ) + " ON " # TODO: likely need asfrom=True here? - + join.onclause._compiler_dispatch(self, **kwargs) + + join.onclause._compiler_dispatch( + self, from_linter=from_linter, **kwargs + ) ) def _setup_crud_hints(self, stmt, table_text): |
