diff options
| -rw-r--r-- | lib/sqlalchemy/dialects/mysql/pyodbc.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/mock.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/attributes.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/interfaces.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/relationships.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/coercions.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 8 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/exclusions.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/util.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/__init__.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/util/compat.py | 10 | ||||
| -rw-r--r-- | setup.cfg | 1 | ||||
| -rw-r--r-- | test/dialect/mssql/test_compiler.py | 4 | ||||
| -rw-r--r-- | test/orm/test_query.py | 103 |
17 files changed, 139 insertions, 83 deletions
diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 6a411b984..eb62e6425 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -126,7 +126,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): conn.setdecoding(pyodbc_SQL_CHAR, encoding="utf-8") conn.setdecoding(pyodbc_SQL_WCHAR, encoding="utf-8") conn.setencoding(str, encoding="utf-8") - conn.setencoding(unicode, encoding="utf-8") + conn.setencoding(unicode, encoding="utf-8") # noqa: F821 return on_connect diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py index d98d0ee3a..570ee2d04 100644 --- a/lib/sqlalchemy/engine/mock.py +++ b/lib/sqlalchemy/engine/mock.py @@ -108,7 +108,7 @@ def create_mock_engine(url, executor, **kw): # consume dialect arguments from kwargs for k in util.get_cls_kwargs(dialect_cls): if k in kw: - dialect_args[k] = kwargs.pop(k) + dialect_args[k] = kw.pop(k) # create dialect dialect = dialect_cls(**dialect_args) diff --git a/lib/sqlalchemy/orm/attributes.py b/lib/sqlalchemy/orm/attributes.py index 1805bf8db..66a18da99 100644 --- a/lib/sqlalchemy/orm/attributes.py +++ b/lib/sqlalchemy/orm/attributes.py @@ -1586,10 +1586,8 @@ def backref_listeners(attribute, key, uselist): _NO_HISTORY = util.symbol("NO_HISTORY") _NO_STATE_SYMBOLS = frozenset([id(PASSIVE_NO_RESULT), id(NO_VALUE)]) -History = util.namedtuple("History", ["added", "unchanged", "deleted"]) - -class History(History): +class History(util.namedtuple("History", ["added", "unchanged", "deleted"])): """A 3-tuple of added, unchanged and deleted values, representing the changes which have occurred on an instrumented attribute. diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py index 21479c08b..c7b059fda 100644 --- a/lib/sqlalchemy/orm/interfaces.py +++ b/lib/sqlalchemy/orm/interfaces.py @@ -39,6 +39,12 @@ from ..sql import operators from ..sql import visitors from ..sql.traversals import HasCacheKey +if util.TYPE_CHECKING: + from typing import Any + from typing import List + from typing import Optional + from .mapper import Mapper + from .util import AliasedInsp __all__ = ( "EXT_CONTINUE", @@ -363,8 +369,12 @@ class PropComparator(operators.ColumnOperators): __slots__ = "prop", "property", "_parententity", "_adapt_to_entity" - def __init__(self, prop, parentmapper, adapt_to_entity=None): - # type: (MapperProperty, Mapper, Optional(AliasedInsp)) + def __init__( + self, + prop, # type: MapperProperty + parentmapper, # type: Mapper + adapt_to_entity=None, # type: Optional[AliasedInsp] + ): self.prop = self.property = prop self._parententity = adapt_to_entity or parentmapper self._adapt_to_entity = adapt_to_entity @@ -372,8 +382,10 @@ class PropComparator(operators.ColumnOperators): def __clause_element__(self): raise NotImplementedError("%r" % self) - def _bulk_update_tuples(self, value): - # type: (ColumnOperators) -> List[tuple[ColumnOperators, Any]] + def _bulk_update_tuples( + self, value # type: (operators.ColumnOperators) + ): + # type: (...) -> List[tuple[operators.ColumnOperators, Any]] """Receive a SQL expression that represents a value in the SET clause of an UPDATE statement. diff --git a/lib/sqlalchemy/orm/relationships.py b/lib/sqlalchemy/orm/relationships.py index f1764672c..d745500c1 100644 --- a/lib/sqlalchemy/orm/relationships.py +++ b/lib/sqlalchemy/orm/relationships.py @@ -50,6 +50,11 @@ from ..sql.util import selectables_overlap from ..sql.util import visit_binary_product +if util.TYPE_CHECKING: + from .util import AliasedInsp + from typing import Union + + def remote(expr): """Annotate a portion of a primaryjoin expression with a 'remote' annotation. @@ -1859,9 +1864,9 @@ class RelationshipProperty(StrategizedProperty): ) @util.memoized_property - def entity(self): # type: () -> Union[AliasedInsp, Mapper] + def entity(self): # type: () -> Union[AliasedInsp, mapperlib.Mapper] """Return the target mapped entity, which is an inspect() of the - class or aliased class tha is referred towards. + class or aliased class that is referred towards. """ if callable(self.argument) and not isinstance( diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 60456223f..a7324c45f 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -19,9 +19,12 @@ from .visitors import ClauseVisitor from .. import exc from .. import util -coercions = None # type: types.ModuleType -elements = None # type: types.ModuleType -type_api = None # type: types.ModuleType +if util.TYPE_CHECKING: + from types import ModuleType + +coercions = None # type: ModuleType +elements = None # type: ModuleType +type_api = None # type: ModuleType PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT") NO_ARG = util.symbol("NO_ARG") diff --git a/lib/sqlalchemy/sql/coercions.py b/lib/sqlalchemy/sql/coercions.py index 12ec7c750..b3bf4e93b 100644 --- a/lib/sqlalchemy/sql/coercions.py +++ b/lib/sqlalchemy/sql/coercions.py @@ -17,10 +17,13 @@ from .. import inspection from .. import util from ..util import collections_abc -elements = None # type: types.ModuleType -schema = None # type: types.ModuleType -selectable = None # type: types.ModuleType -sqltypes = None # type: types.ModuleType +if util.TYPE_CHECKING: + from types import ModuleType + +elements = None # type: ModuleType +schema = None # type: ModuleType +selectable = None # type: ModuleType +sqltypes = None # type: ModuleType def _is_literal(element): diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index d94d91b16..422eb6220 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -43,6 +43,11 @@ from .. import exc from .. import inspection from .. import util +if util.TYPE_CHECKING: + from typing import Any + from typing import Optional + from typing import Union + def collate(expression, collation): """Return the clause ``expression COLLATE collation``. @@ -709,7 +714,7 @@ class ColumnElement( _alt_names = () def self_group(self, against=None): - # type: (Module, Module, Optional[Any]) -> ClauseEleent + # type: (Optional[Any]) -> ClauseElement if ( against in (operators.and_, operators.or_, operators._asbool) and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 8f5503db0..136c9f868 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -50,6 +50,10 @@ from .visitors import InternalTraversal from .. import exc from .. import util +if util.TYPE_CHECKING: + from typing import Any + from typing import Optional + class _OffsetLimitParam(BindParameter): @property @@ -2096,7 +2100,7 @@ class SelectBase( _memoized_property = util.group_expirable_memoized_property() def _generate_fromclause_column_proxies(self, fromclause): - # type: (FromClause) + # type: (FromClause) -> None raise NotImplementedError() def _refresh_for_new_column(self, column): @@ -2344,7 +2348,7 @@ class SelectStatementGrouping(GroupedElement, SelectBase): _is_select_container = True def __init__(self, element): - # type: (SelectBase) + # type: (SelectBase) -> None self.element = coercions.expect(roles.SelectStatementRole, element) @property diff --git a/lib/sqlalchemy/testing/__init__.py b/lib/sqlalchemy/testing/__init__.py index e5425dd81..ab1198da8 100644 --- a/lib/sqlalchemy/testing/__init__.py +++ b/lib/sqlalchemy/testing/__init__.py @@ -54,6 +54,7 @@ from .util import adict # noqa from .util import fail # noqa from .util import flag_combinations # noqa from .util import force_drop_names # noqa +from .util import lambda_combinations # noqa from .util import metadata_fixture # noqa from .util import provide_metadata # noqa from .util import resolve_lambda # noqa diff --git a/lib/sqlalchemy/testing/exclusions.py b/lib/sqlalchemy/testing/exclusions.py index 8b17f64c7..b2828b107 100644 --- a/lib/sqlalchemy/testing/exclusions.py +++ b/lib/sqlalchemy/testing/exclusions.py @@ -159,7 +159,9 @@ class compound(object): for fail in self.fails: if self._check_combinations(combination, fail) and fail(config): if util.py2k: - str_ex = unicode(ex).encode("utf-8", errors="ignore") + str_ex = unicode(ex).encode( # noqa: F821 + "utf-8", errors="ignore" + ) else: str_ex = str(ex) print( diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 74c9b1aeb..de20bb794 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -12,13 +12,14 @@ import sys import time import types +from . import mock from ..util import decorator from ..util import defaultdict +from ..util import inspect_getfullargspec from ..util import jython from ..util import py2k from ..util import pypy - if jython: def jython_gc_collect(*args): @@ -276,6 +277,25 @@ def flag_combinations(*combinations): ) +def lambda_combinations(lambda_arg_sets, **kw): + from . import config + + args = inspect_getfullargspec(lambda_arg_sets) + + arg_sets = lambda_arg_sets(*[mock.Mock() for arg in args[0]]) + + def create_fixture(pos): + def fixture(**kw): + return lambda_arg_sets(**kw)[pos] + + fixture.__name__ = "fixture_%3.3d" % pos + return fixture + + return config.combinations( + *[(create_fixture(i),) for i in range(len(arg_sets))], **kw + ) + + def resolve_lambda(__fn, **kw): """Given a no-arg lambda and a namespace, return a new lambda that has all the values filled in. @@ -285,10 +305,12 @@ def resolve_lambda(__fn, **kw): """ + pos_args = inspect_getfullargspec(__fn)[0] + pass_pos_args = {arg: kw.pop(arg) for arg in pos_args} glb = dict(__fn.__globals__) glb.update(kw) new_fn = types.FunctionType(__fn.__code__, glb) - return new_fn() + return new_fn(**pass_pos_args) def metadata_fixture(ddl="function"): diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py index c1790439c..a19636f62 100644 --- a/lib/sqlalchemy/util/__init__.py +++ b/lib/sqlalchemy/util/__init__.py @@ -80,6 +80,7 @@ from .compat import StringIO # noqa from .compat import text_type # noqa from .compat import threading # noqa from .compat import timezone # noqa +from .compat import TYPE_CHECKING # noqa from .compat import u # noqa from .compat import ue # noqa from .compat import unquote # noqa diff --git a/lib/sqlalchemy/util/compat.py b/lib/sqlalchemy/util/compat.py index 82bf68d2c..2cb5db5d4 100644 --- a/lib/sqlalchemy/util/compat.py +++ b/lib/sqlalchemy/util/compat.py @@ -158,6 +158,8 @@ if py3k: def ue(s): return s + from typing import TYPE_CHECKING + # Unused. Kept for backwards compatibility. callable = callable # noqa else: @@ -244,8 +246,10 @@ else: def safe_bytestring(text): # py2k only if not isinstance(text, string_types): - return unicode(text).encode("ascii", errors="backslashreplace") - elif isinstance(text, unicode): + return unicode(text).encode( # noqa: F821 + "ascii", errors="backslashreplace" + ) + elif isinstance(text, unicode): # noqa: F821 return text.encode("ascii", errors="backslashreplace") else: return text @@ -259,6 +263,8 @@ else: " raise tp, value, tb\n" ) + TYPE_CHECKING = False + if py35: from inspect import formatannotation @@ -20,7 +20,6 @@ ignore = A003, D, E203,E305,E711,E712,E721,E722,E741, - F821 N801,N802,N806, RST304,RST303,RST299,RST399, W503,W504 diff --git a/test/dialect/mssql/test_compiler.py b/test/dialect/mssql/test_compiler.py index 9d46c3f35..bb5199b00 100644 --- a/test/dialect/mssql/test_compiler.py +++ b/test/dialect/mssql/test_compiler.py @@ -315,7 +315,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): }, ), ( - lambda: select([t]).where(t.c.foo.in_(["x", "y", "z"])), + lambda t: select([t]).where(t.c.foo.in_(["x", "y", "z"])), "SELECT sometable.foo FROM sometable WHERE sometable.foo " "IN ([POSTCOMPILE_foo_1])", { @@ -323,7 +323,7 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL): "check_post_param": {}, }, ), - (lambda: t.c.foo.in_([None]), "sometable.foo IN (NULL)", {}), + (lambda t: t.c.foo.in_([None]), "sometable.foo IN (NULL)", {}), ) def test_strict_binds(self, expr, compiled, kw): """test the 'strict' compiler binds.""" diff --git a/test/orm/test_query.py b/test/orm/test_query.py index 558b4d91c..271d85dd6 100644 --- a/test/orm/test_query.py +++ b/test/orm/test_query.py @@ -161,7 +161,7 @@ class RowTupleTest(QueryTest): is_(sess.query(ex)._deep_entity_zero(), inspect(User)) @testing.combinations( - lambda: ( + lambda sess, User: ( sess.query(User), [ { @@ -173,7 +173,7 @@ class RowTupleTest(QueryTest): } ], ), - lambda: ( + lambda sess, User, users: ( sess.query(User.id, User), [ { @@ -192,7 +192,7 @@ class RowTupleTest(QueryTest): }, ], ), - lambda: ( + lambda sess, User, user_alias, users: ( sess.query(User.id, user_alias), [ { @@ -211,7 +211,7 @@ class RowTupleTest(QueryTest): }, ], ), - lambda: ( + lambda sess, user_alias, users: ( sess.query(user_alias.id), [ { @@ -223,7 +223,7 @@ class RowTupleTest(QueryTest): } ], ), - lambda: ( + lambda sess, user_alias_id_label, users, user_alias: ( sess.query(user_alias_id_label), [ { @@ -235,7 +235,7 @@ class RowTupleTest(QueryTest): } ], ), - lambda: ( + lambda sess, address_alias, Address: ( sess.query(address_alias), [ { @@ -247,7 +247,7 @@ class RowTupleTest(QueryTest): } ], ), - lambda: ( + lambda sess, name_label, fn, users, User: ( sess.query(name_label, fn), [ { @@ -266,7 +266,7 @@ class RowTupleTest(QueryTest): }, ], ), - lambda: ( + lambda sess, cte: ( sess.query(cte), [ { @@ -278,7 +278,7 @@ class RowTupleTest(QueryTest): } ], ), - lambda: ( + lambda sess, subq1: ( sess.query(subq1.c.id), [ { @@ -290,7 +290,7 @@ class RowTupleTest(QueryTest): } ], ), - lambda: ( + lambda sess, subq2: ( sess.query(subq2.c.id), [ { @@ -302,7 +302,7 @@ class RowTupleTest(QueryTest): } ], ), - lambda: ( + lambda sess, users: ( sess.query(users), [ { @@ -321,7 +321,7 @@ class RowTupleTest(QueryTest): }, ], ), - lambda: ( + lambda sess, users: ( sess.query(users.c.name), [ { @@ -333,7 +333,7 @@ class RowTupleTest(QueryTest): } ], ), - lambda: ( + lambda sess, bundle, User: ( sess.query(bundle), [ { @@ -968,9 +968,9 @@ class GetTest(QueryTest): class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): @testing.combinations( - lambda: s.query(User).limit(2), - lambda: s.query(User).filter(User.id == 1).offset(2), - lambda: s.query(User).limit(2).offset(2), + lambda s, User: s.query(User).limit(2), + lambda s, User: s.query(User).filter(User.id == 1).offset(2), + lambda s, User: s.query(User).limit(2).offset(2), ) def test_no_limit_offset(self, test_case): User = self.classes.User @@ -1128,11 +1128,11 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): is_(q1._entity_zero(), inspect(User)) @testing.combinations( - lambda: s.query(User).filter(User.id == 5), - lambda: s.query(User).filter_by(id=5), - lambda: s.query(User).limit(5), - lambda: s.query(User).group_by(User.name), - lambda: s.query(User).order_by(User.name), + lambda s, User: s.query(User).filter(User.id == 5), + lambda s, User: s.query(User).filter_by(id=5), + lambda s, User: s.query(User).limit(5), + lambda s, User: s.query(User).group_by(User.name), + lambda s, User: s.query(User).order_by(User.name), ) def test_from_statement(self, test_case): User = self.classes.User @@ -1144,11 +1144,11 @@ class InvalidGenerationsTest(QueryTest, AssertsCompiledSQL): assert_raises(sa_exc.InvalidRequestError, q.from_statement, text("x")) @testing.combinations( - (Query.filter, lambda: meth(User.id == 5)), - (Query.filter_by, lambda: meth(id=5)), - (Query.limit, lambda: meth(5)), - (Query.group_by, lambda: meth(User.name)), - (Query.order_by, lambda: meth(User.name)), + (Query.filter, lambda meth, User: meth(User.id == 5)), + (Query.filter_by, lambda meth: meth(id=5)), + (Query.limit, lambda meth: meth(5)), + (Query.group_by, lambda meth, User: meth(User.name)), + (Query.order_by, lambda meth, User: meth(User.name)), ) def test_from_statement_text(self, meth, test_case): @@ -1251,13 +1251,13 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): id_="ar", ) @testing.combinations( - (lambda: 5, lambda: User.id, ":id_1 %s users.id"), + (lambda User: 5, lambda User: User.id, ":id_1 %s users.id"), (lambda: 5, lambda: literal(6), ":param_1 %s :param_2"), - (lambda: User.id, lambda: 5, "users.id %s :id_1"), - (lambda: User.id, lambda: literal("b"), "users.id %s :param_1"), - (lambda: User.id, lambda: User.id, "users.id %s users.id"), + (lambda User: User.id, lambda: 5, "users.id %s :id_1"), + (lambda User: User.id, lambda: literal("b"), "users.id %s :param_1"), + (lambda User: User.id, lambda User: User.id, "users.id %s users.id"), (lambda: literal(5), lambda: "b", ":param_1 %s :param_2"), - (lambda: literal(5), lambda: User.id, ":param_1 %s users.id"), + (lambda: literal(5), lambda User: User.id, ":param_1 %s users.id"), (lambda: literal(5), lambda: literal(6), ":param_1 %s :param_2"), argnames="lhs, rhs, res", id_="aar", @@ -1280,35 +1280,30 @@ class OperatorTest(QueryTest, AssertsCompiledSQL): id_="arr", argnames="py_op, fwd_op, rev_op", ) - @testing.combinations( - (lambda: "a", lambda: User.id, ":id_1", "users.id"), - ( - lambda: "a", - lambda: literal("b"), - ":param_2", - ":param_1", - ), # note swap! - (lambda: User.id, lambda: "b", "users.id", ":id_1"), - (lambda: User.id, lambda: literal("b"), "users.id", ":param_1"), - (lambda: User.id, lambda: User.id, "users.id", "users.id"), - (lambda: literal("a"), lambda: "b", ":param_1", ":param_2"), - (lambda: literal("a"), lambda: User.id, ":param_1", "users.id"), - (lambda: literal("a"), lambda: literal("b"), ":param_1", ":param_2"), - (lambda: ualias.id, lambda: literal("b"), "users_1.id", ":param_1"), - (lambda: User.id, lambda: ualias.name, "users.id", "users_1.name"), - (lambda: User.name, lambda: ualias.name, "users.name", "users_1.name"), - (lambda: ualias.name, lambda: User.name, "users_1.name", "users.name"), - argnames="lhs, rhs, l_sql, r_sql", - id_="aarr", + @testing.lambda_combinations( + lambda User, ualias: ( + ("a", User.id, ":id_1", "users.id"), + ("a", literal("b"), ":param_2", ":param_1"), # note swap! + (User.id, "b", "users.id", ":id_1"), + (User.id, literal("b"), "users.id", ":param_1"), + (User.id, User.id, "users.id", "users.id"), + (literal("a"), "b", ":param_1", ":param_2"), + (literal("a"), User.id, ":param_1", "users.id"), + (literal("a"), literal("b"), ":param_1", ":param_2"), + (ualias.id, literal("b"), "users_1.id", ":param_1"), + (User.id, ualias.name, "users.id", "users_1.name"), + (User.name, ualias.name, "users.name", "users_1.name"), + (ualias.name, User.name, "users_1.name", "users.name"), + ), + argnames="fixture", ) - def test_comparison(self, py_op, fwd_op, rev_op, lhs, rhs, l_sql, r_sql): + def test_comparison(self, py_op, fwd_op, rev_op, fixture): User = self.classes.User create_session().query(User) ualias = aliased(User) - lhs = testing.resolve_lambda(lhs, User=User, ualias=ualias) - rhs = testing.resolve_lambda(rhs, User=User, ualias=ualias) + lhs, rhs, l_sql, r_sql = fixture(User=User, ualias=ualias) # the compiled clause should match either (e.g.): # 'a' < 'b' -or- 'b' > 'a'. |
