diff options
| author | Federico Caselli <cfederico87@gmail.com> | 2020-03-07 19:17:07 +0100 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2020-03-07 17:50:45 -0500 |
| commit | eda6dbbf387def2063d1b6719b64b20f9e7f2ab4 (patch) | |
| tree | 4af5f41edfac169b0fdc6d6cab0fce4e8bf776cf /lib/sqlalchemy/sql | |
| parent | 851fb8f5a661c66ee76308181118369c8c4df9e0 (diff) | |
| download | sqlalchemy-eda6dbbf387def2063d1b6719b64b20f9e7f2ab4.tar.gz | |
Simplified module pre-loading strategy and made it linter friendly
Introduced a modules registry to register modules that should be lazily loaded
in the package init. This ensures that they are in the system module cache,
avoiding potential thread safety issues as when importing them directly
in the function that uses them. The module registry is used to obtain
these modules directly, ensuring that the all the lazily loaded modules
are resolved at the proper time
This replaces dependency_for decorator and the dependencies decorator logic,
removing the need to pass the resolved modules as arguments of the
decodated functions and removes possible errors caused by linters.
Fixes: #4689
Fixes: #4656
Change-Id: I2e291eba4297867fc0ddb5d875b9f7af34751d01
Diffstat (limited to 'lib/sqlalchemy/sql')
| -rw-r--r-- | lib/sqlalchemy/sql/__init__.py | 2 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/schema.py | 26 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 59 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 23 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/traversals.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/type_api.py | 15 |
9 files changed, 91 insertions, 76 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py index 488717041..281b7d0f2 100644 --- a/lib/sqlalchemy/sql/__init__.py +++ b/lib/sqlalchemy/sql/__init__.py @@ -120,7 +120,7 @@ def __go(lcls): _prepare_annotations(FromClause, AnnotatedFromClause) _prepare_annotations(ClauseList, Annotated) - _sa_util.dependencies.resolve_all("sqlalchemy.sql") + _sa_util.preloaded.import_prefix("sqlalchemy.sql") from . import naming # noqa diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 89839ea28..b61c7dc5e 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -209,6 +209,14 @@ class _DialectArgDict(util.collections_abc.MutableMapping): del self._non_defaults[key] +@util.preload_module("sqlalchemy.dialects") +def _kw_reg_for_dialect(dialect_name): + dialect_cls = util.preloaded.dialects.registry.load(dialect_name) + if dialect_cls.construct_arguments is None: + return None + return dict(dialect_cls.construct_arguments) + + class DialectKWArgs(object): """Establish the ability for a class to have dialect-specific arguments with defaults and constructor validation. @@ -307,13 +315,6 @@ class DialectKWArgs(object): """A synonym for :attr:`.DialectKWArgs.dialect_kwargs`.""" return self.dialect_kwargs - @util.dependencies("sqlalchemy.dialects") - def _kw_reg_for_dialect(dialects, dialect_name): - dialect_cls = dialects.registry.load(dialect_name) - if dialect_cls.construct_arguments is None: - return None - return dict(dialect_cls.construct_arguments) - _kw_registry = util.PopulateDict(_kw_reg_for_dialect) def _kw_reg_for_dialect_cls(self, dialect_name): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 3ebcf24b0..b37c46216 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -1022,9 +1022,10 @@ class SQLCompiler(Compiled): return expanded_state - @util.dependencies("sqlalchemy.engine.result") - def _create_result_map(self, result): + @util.preload_module("sqlalchemy.engine.result") + def _create_result_map(self): """utility method used for unit tests only.""" + result = util.preloaded.engine_result return result.CursorResultMetaData._create_description_match_map( self._result_columns ) @@ -4127,8 +4128,10 @@ class IdentifierPreparer(object): ident = self.quote_identifier(ident) return ident - @util.dependencies("sqlalchemy.sql.naming") - def format_constraint(self, naming, constraint, _alembic_quote=True): + @util.preload_module("sqlalchemy.sql.naming") + def format_constraint(self, constraint, _alembic_quote=True): + naming = util.preloaded.sql_naming + if isinstance(constraint.name, elements._defer_name): name = naming._constraint_name_for_table( constraint, constraint.table diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index 47739a37d..bb68e8a7e 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -422,8 +422,8 @@ class ClauseElement( return self - @util.dependencies("sqlalchemy.engine.default") - def compile(self, default, bind=None, dialect=None, **kw): + @util.preload_module("sqlalchemy.engine.default") + def compile(self, bind=None, dialect=None, **kw): """Compile this SQL expression. The return value is a :class:`~.Compiled` object. @@ -477,6 +477,7 @@ class ClauseElement( """ + default = util.preloaded.engine_default if not dialect: if bind: dialect = bind.dialect @@ -1782,8 +1783,8 @@ class TextClause( else: new_params[key] = existing._with_value(value) - @util.dependencies("sqlalchemy.sql.selectable") - def columns(self, selectable, *cols, **types): + @util.preload_module("sqlalchemy.sql.selectable") + def columns(self, *cols, **types): r"""Turn this :class:`.TextClause` object into a :class:`.TextualSelect` object that serves the same role as a SELECT statement. @@ -1888,6 +1889,7 @@ class TextClause( argument as it also indicates positional ordering. """ + selectable = util.preloaded.sql_selectable positional_input_cols = [ ColumnClause(col.key, types.pop(col.key)) if col.key in types diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index 4c627c4cc..69f60ba24 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -2193,8 +2193,10 @@ class ColumnDefault(DefaultGenerator): ) @util.memoized_property - @util.dependencies("sqlalchemy.sql.sqltypes") - def _arg_is_typed(self, sqltypes): + @util.preload_module("sqlalchemy.sql.sqltypes") + def _arg_is_typed(self): + sqltypes = util.preloaded.sql_sqltypes + if self.is_clause_element: return not isinstance(self.arg.type, sqltypes.NullType) else: @@ -2440,14 +2442,16 @@ class Sequence(roles.StatementRole, DefaultGenerator): def is_clause_element(self): return False - @util.dependencies("sqlalchemy.sql.functions.func") - def next_value(self, func): + @util.preload_module("sqlalchemy.sql.functions") + def next_value(self): """Return a :class:`.next_value` function element which will render the appropriate increment function for this :class:`.Sequence` within any SQL expression. """ - return func.next_value(self, bind=self.bind) + return util.preloaded.sql_functions.func.next_value( + self, bind=self.bind + ) def _set_parent(self, column): super(Sequence, self)._set_parent(column) @@ -3925,10 +3929,10 @@ class MetaData(SchemaItem): """ return self._bind - @util.dependencies("sqlalchemy.engine.url") - def _bind_to(self, url, bind): + @util.preload_module("sqlalchemy.engine.url") + def _bind_to(self, bind): """Bind this MetaData to an Engine, Connection, string or URL.""" - + url = util.preloaded.engine_url if isinstance(bind, util.string_types + (url.URL,)): self._bind = sqlalchemy.create_engine(bind) else: @@ -4231,10 +4235,10 @@ class ThreadLocalMetaData(MetaData): return getattr(self.context, "_engine", None) - @util.dependencies("sqlalchemy.engine.url") - def _bind_to(self, url, bind): + @util.preload_module("sqlalchemy.engine.url") + def _bind_to(self, bind): """Bind to a Connectable in the caller's thread.""" - + url = util.preloaded.engine_url if isinstance(bind, util.string_types + (url.URL,)): try: self.context._engine = self.__engines[bind] diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index 965ac6e7f..5536b27bc 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -164,14 +164,13 @@ class Selectable(ReturnsRows): "deprecated, and will be removed in a future release. Similar " "functionality is available via the sqlalchemy.sql.visitors module.", ) - @util.dependencies("sqlalchemy.sql.util") - def replace_selectable(self, sqlutil, old, alias): + @util.preload_module("sqlalchemy.sql.util") + def replace_selectable(self, old, alias): """replace all occurrences of FromClause 'old' with the given Alias object, returning a copy of this :class:`.FromClause`. """ - - return sqlutil.ClauseAdapter(alias).traverse(self) + return util.preloaded.sql_util.ClauseAdapter(alias).traverse(self) def corresponding_column(self, column, require_embedded=False): """Given a :class:`.ColumnElement`, return the exported @@ -358,8 +357,8 @@ class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable): ":class:`.functions.count` function available from the " ":attr:`.func` namespace.", ) - @util.dependencies("sqlalchemy.sql.functions") - def count(self, functions, whereclause=None, **params): + @util.preload_module("sqlalchemy.sql.functions") + def count(self, whereclause=None, **params): """return a SELECT COUNT generated against this :class:`.FromClause`. @@ -368,7 +367,7 @@ class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable): :class:`.functions.count` """ - + functions = util.preloaded.sql_functions if self.primary_key: col = list(self.primary_key)[0] else: @@ -801,8 +800,9 @@ class Join(FromClause): def self_group(self, against=None): return FromGrouping(self) - @util.dependencies("sqlalchemy.sql.util") - def _populate_column_collection(self, sqlutil): + @util.preload_module("sqlalchemy.sql.util") + def _populate_column_collection(self): + sqlutil = util.preloaded.sql_util columns = [c for c in self.left.columns] + [ c for c in self.right.columns ] @@ -1033,8 +1033,8 @@ class Join(FromClause): def bind(self): return self.left.bind or self.right.bind - @util.dependencies("sqlalchemy.sql.util") - def alias(self, sqlutil, name=None, flat=False): + @util.preload_module("sqlalchemy.sql.util") + def alias(self, name=None, flat=False): r"""return an alias of this :class:`.Join`. The default behavior here is to first produce a SELECT @@ -1134,6 +1134,7 @@ class Join(FromClause): :func:`~.expression.alias` """ + sqlutil = util.preloaded.sql_util if flat: assert name is None, "Can't send name argument with flat" left_a, right_a = ( @@ -1458,8 +1459,9 @@ class TableSample(AliasedReturnsRows): self.seed = seed super(TableSample, self)._init(selectable, name=name) - @util.dependencies("sqlalchemy.sql.functions") - def _get_method(self, functions): + @util.preload_module("sqlalchemy.sql.functions") + def _get_method(self): + functions = util.preloaded.sql_functions if isinstance(self.sampling, functions.Function): return self.sampling else: @@ -1929,8 +1931,8 @@ class TableClause(Immutable, FromClause): self._columns.add(c) c.table = self - @util.dependencies("sqlalchemy.sql.dml") - def insert(self, dml, values=None, inline=False, **kwargs): + @util.preload_module("sqlalchemy.sql.dml") + def insert(self, values=None, inline=False, **kwargs): """Generate an :func:`.insert` construct against this :class:`.TableClause`. @@ -1941,13 +1943,12 @@ class TableClause(Immutable, FromClause): See :func:`.insert` for argument and usage information. """ + return util.preloaded.sql_dml.Insert( + self, values=values, inline=inline, **kwargs + ) - return dml.Insert(self, values=values, inline=inline, **kwargs) - - @util.dependencies("sqlalchemy.sql.dml") - def update( - self, dml, whereclause=None, values=None, inline=False, **kwargs - ): + @util.preload_module("sqlalchemy.sql.dml") + def update(self, whereclause=None, values=None, inline=False, **kwargs): """Generate an :func:`.update` construct against this :class:`.TableClause`. @@ -1958,8 +1959,7 @@ class TableClause(Immutable, FromClause): See :func:`.update` for argument and usage information. """ - - return dml.Update( + return util.preloaded.sql_dml.Update( self, whereclause=whereclause, values=values, @@ -1967,8 +1967,8 @@ class TableClause(Immutable, FromClause): **kwargs ) - @util.dependencies("sqlalchemy.sql.dml") - def delete(self, dml, whereclause=None, **kwargs): + @util.preload_module("sqlalchemy.sql.dml") + def delete(self, whereclause=None, **kwargs): """Generate a :func:`.delete` construct against this :class:`.TableClause`. @@ -1979,8 +1979,7 @@ class TableClause(Immutable, FromClause): See :func:`.delete` for argument and usage information. """ - - return dml.Delete(self, whereclause, **kwargs) + return util.preloaded.sql_dml.Delete(self, whereclause, **kwargs) @property def _from_objects(self): @@ -3864,8 +3863,8 @@ class Select( """ return self.add_columns(column) - @util.dependencies("sqlalchemy.sql.util") - def reduce_columns(self, sqlutil, only_synonyms=True): + @util.preload_module("sqlalchemy.sql.util") + def reduce_columns(self, only_synonyms=True): """Return a new :func`.select` construct with redundantly named, equivalently-valued columns removed from the columns clause. @@ -3887,7 +3886,7 @@ class Select( """ return self.with_only_columns( - sqlutil.reduce_columns( + util.preloaded.sql_util.reduce_columns( self.inner_columns, only_synonyms=only_synonyms, *(self._whereclause,) + tuple(self._from_obj) diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 2d6b44299..3d69d1177 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -1539,8 +1539,9 @@ class Enum(Emulated, String, SchemaType): not self.native_enum or not compiler.dialect.supports_native_enum ) - @util.dependencies("sqlalchemy.sql.schema") - def _set_table(self, schema, column, table): + @util.preload_module("sqlalchemy.sql.schema") + def _set_table(self, column, table): + schema = util.preloaded.sql_schema SchemaType._set_table(self, column, table) if not self.create_constraint: @@ -1738,8 +1739,9 @@ class Boolean(Emulated, TypeEngine, SchemaType): and compiler.dialect.non_native_boolean_check_constraint ) - @util.dependencies("sqlalchemy.sql.schema") - def _set_table(self, schema, column, table): + @util.preload_module("sqlalchemy.sql.schema") + def _set_table(self, column, table): + schema = util.preloaded.sql_schema if not self.create_constraint: return @@ -2228,8 +2230,7 @@ class JSON(Indexable, TypeEngine): class Comparator(Indexable.Comparator, Concatenable.Comparator): """Define comparison operations for :class:`.types.JSON`.""" - @util.dependencies("sqlalchemy.sql.default_comparator") - def _setup_getitem(self, default_comparator, index): + def _setup_getitem(self, index): if not isinstance(index, util.string_types) and isinstance( index, compat.collections_abc.Sequence ): @@ -2553,8 +2554,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): "ARRAY type; please use the dialect-specific ARRAY type" ) - @util.dependencies("sqlalchemy.sql.elements") - def any(self, elements, other, operator=None): + @util.preload_module("sqlalchemy.sql.elements") + def any(self, other, operator=None): """Return ``other operator ANY (array)`` clause. Argument places are switched, because ANY requires array @@ -2582,14 +2583,15 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): :meth:`.types.ARRAY.Comparator.all` """ + elements = util.preloaded.sql_elements operator = operator if operator else operators.eq return operator( coercions.expect(roles.ExpressionElementRole, other), elements.CollectionAggregate._create_any(self.expr), ) - @util.dependencies("sqlalchemy.sql.elements") - def all(self, elements, other, operator=None): + @util.preload_module("sqlalchemy.sql.elements") + def all(self, other, operator=None): """Return ``other operator ALL (array)`` clause. Argument places are switched, because ALL requires array @@ -2617,6 +2619,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine): :meth:`.types.ARRAY.Comparator.any` """ + elements = util.preloaded.sql_elements operator = operator if operator else operators.eq return operator( coercions.expect(roles.ExpressionElementRole, other), diff --git a/lib/sqlalchemy/sql/traversals.py b/lib/sqlalchemy/sql/traversals.py index c29a04ee0..2a4f7ebb6 100644 --- a/lib/sqlalchemy/sql/traversals.py +++ b/lib/sqlalchemy/sql/traversals.py @@ -597,9 +597,9 @@ class _GetChildren(InternalTraversal): _get_children = _GetChildren() -@util.dependencies("sqlalchemy.sql.elements") -def _resolve_name_for_compare(elements, element, name, anon_map, **kw): - if isinstance(name, elements._anonymous_label): +@util.preload_module("sqlalchemy.sql.elements") +def _resolve_name_for_compare(element, name, anon_map, **kw): + if isinstance(name, util.preloaded.sql_elements._anonymous_label): name = name.apply_map(anon_map) return name diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py index 739f96195..38189ec9d 100644 --- a/lib/sqlalchemy/sql/type_api.py +++ b/lib/sqlalchemy/sql/type_api.py @@ -64,13 +64,15 @@ class TypeEngine(Traversible): self.expr = expr self.type = expr.type - @util.dependencies("sqlalchemy.sql.default_comparator") - def operate(self, default_comparator, op, *other, **kwargs): + @util.preload_module("sqlalchemy.sql.default_comparator") + def operate(self, op, *other, **kwargs): + default_comparator = util.preloaded.sql_default_comparator o = default_comparator.operator_lookup[op.__name__] return o[0](self.expr, op, *(other + o[1:]), **kwargs) - @util.dependencies("sqlalchemy.sql.default_comparator") - def reverse_operate(self, default_comparator, op, other, **kwargs): + @util.preload_module("sqlalchemy.sql.default_comparator") + def reverse_operate(self, op, other, **kwargs): + default_comparator = util.preloaded.sql_default_comparator o = default_comparator.operator_lookup[op.__name__] return o[0](self.expr, op, other, reverse=True, *o[1:], **kwargs) @@ -613,8 +615,9 @@ class TypeEngine(Traversible): return dialect.type_compiler.process(self) - @util.dependencies("sqlalchemy.engine.default") - def _default_dialect(self, default): + @util.preload_module("sqlalchemy.engine.default") + def _default_dialect(self): + default = util.preloaded.engine_default if self.__class__.__module__.startswith("sqlalchemy.dialects"): tokens = self.__class__.__module__.split(".")[0:3] mod = ".".join(tokens) |
