diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-25 17:08:48 -0400 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2022-03-30 14:04:52 -0400 |
| commit | 4e754a8914a1c2c16c97bdf363d2e24bfa823730 (patch) | |
| tree | db723242b4e4c0d4c7f15c167857dd79fdfa6ccb /lib/sqlalchemy/sql/base.py | |
| parent | dba480ebaf89c0b5ea787661583de9da3928920f (diff) | |
| download | sqlalchemy-4e754a8914a1c2c16c97bdf363d2e24bfa823730.tar.gz | |
pep-484: the pep-484ening, SQL part three
hitting DML which is causing us to open up the
ColumnCollection structure a bit, as we do put anonymous
column expressions with None here. However, we still want
Table /TableClause to have named column collections that
don't return None, so parametrize the "key" in this
collection also.
* rename some "immutable" elements to "readonly". we change
the contents of immutablecolumncollection underneath, so it's
not "immutable"
Change-Id: I2593995a4e5c6eae874bed5bf76117198be8ae97
Diffstat (limited to 'lib/sqlalchemy/sql/base.py')
| -rw-r--r-- | lib/sqlalchemy/sql/base.py | 110 |
1 files changed, 67 insertions, 43 deletions
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 8f5135915..19e4c13d2 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -62,10 +62,14 @@ if TYPE_CHECKING: from . import coercions from . import elements from . import type_api + from ._typing import _ColumnsClauseArgument + from ._typing import _SelectIterable from .elements import BindParameter from .elements import ColumnClause from .elements import ColumnElement + from .elements import NamedColumn from .elements import SQLCoreOperations + from .selectable import FromClause from ..engine import Connection from ..engine import Result from ..engine.base import _CompiledCacheType @@ -91,6 +95,8 @@ class _NoArg(Enum): NO_ARG = _NoArg.NO_ARG +_T = TypeVar("_T", bound=Any) + _Fn = TypeVar("_Fn", bound=Callable[..., Any]) _AmbiguousTableNameMap = MutableMapping[str, str] @@ -102,7 +108,9 @@ class _EntityNamespace(Protocol): class _HasEntityNamespace(Protocol): - entity_namespace: _EntityNamespace + @util.ro_non_memoized_property + def entity_namespace(self) -> _EntityNamespace: + ... def _is_has_entity_namespace(element: Any) -> TypeGuard[_HasEntityNamespace]: @@ -136,8 +144,8 @@ class SingletonConstant(Immutable): _singleton: SingletonConstant - def __new__(cls, *arg, **kw): - return cls._singleton + def __new__(cls: _T, *arg: Any, **kw: Any) -> _T: + return cast(_T, cls._singleton) @util.non_memoized_property def proxy_set(self) -> FrozenSet[ColumnElement[Any]]: @@ -159,13 +167,15 @@ class SingletonConstant(Immutable): cls._singleton = obj -def _from_objects(*elements): +def _from_objects(*elements: ColumnElement[Any]) -> Iterator[FromClause]: return itertools.chain.from_iterable( [element._from_objects for element in elements] ) -def _select_iterables(elements): +def _select_iterables( + elements: Iterable[roles.ColumnsClauseRole], +) -> _SelectIterable: """expand tables into individual columns in the given list of column expressions. @@ -207,7 +217,7 @@ def _generative(fn: _Fn) -> _Fn: return decorated # type: ignore -def _exclusive_against(*names, **kw): +def _exclusive_against(*names: str, **kw: Any) -> Callable[[_Fn], _Fn]: msgs = kw.pop("msgs", {}) defaults = kw.pop("defaults", {}) @@ -502,7 +512,7 @@ class DialectKWArgs: util.portable_instancemethod(self._kw_reg_for_dialect_cls) ) - def _validate_dialect_kwargs(self, kwargs): + def _validate_dialect_kwargs(self, kwargs: Any) -> None: # validate remaining kwargs that they all specify DB prefixes if not kwargs: @@ -605,7 +615,9 @@ class CompileState: self.statement = statement @classmethod - def get_plugin_class(cls, statement): + def get_plugin_class( + cls, statement: Executable + ) -> Optional[Type[CompileState]]: plugin_name = statement._propagate_attrs.get( "compile_state_plugin", None ) @@ -634,7 +646,9 @@ class CompileState: return None @classmethod - def plugin_for(cls, plugin_name, visit_name): + def plugin_for( + cls, plugin_name: str, visit_name: str + ) -> Callable[[_Fn], _Fn]: def decorate(cls_to_decorate): cls.plugins[(plugin_name, visit_name)] = cls_to_decorate return cls_to_decorate @@ -957,7 +971,7 @@ class Executable(roles.StatementRole, Generative): ) -> Result: ... - @property + @util.non_memoized_property def _all_selected_columns(self): raise NotImplementedError() @@ -1202,10 +1216,11 @@ class SchemaVisitor(ClauseVisitor): __traverse_options__ = {"schema_visitor": True} -_COL = TypeVar("_COL", bound="ColumnClause[Any]") +_COLKEY = TypeVar("_COLKEY", Union[None, str], str) +_COL = TypeVar("_COL", bound="ColumnElement[Any]") -class ColumnCollection(Generic[_COL]): +class ColumnCollection(Generic[_COLKEY, _COL]): """Collection of :class:`_expression.ColumnElement` instances, typically for :class:`_sql.FromClause` objects. @@ -1316,25 +1331,27 @@ class ColumnCollection(Generic[_COL]): __slots__ = "_collection", "_index", "_colset" - _collection: List[Tuple[str, _COL]] - _index: Dict[Union[str, int], _COL] + _collection: List[Tuple[_COLKEY, _COL]] + _index: Dict[Union[None, str, int], _COL] _colset: Set[_COL] - def __init__(self, columns: Optional[Iterable[Tuple[str, _COL]]] = None): + def __init__( + self, columns: Optional[Iterable[Tuple[_COLKEY, _COL]]] = None + ): object.__setattr__(self, "_colset", set()) object.__setattr__(self, "_index", {}) object.__setattr__(self, "_collection", []) if columns: self._initial_populate(columns) - def _initial_populate(self, iter_: Iterable[Tuple[str, _COL]]) -> None: + def _initial_populate(self, iter_: Iterable[Tuple[_COLKEY, _COL]]) -> None: self._populate_separate_keys(iter_) @property def _all_columns(self) -> List[_COL]: return [col for (k, col) in self._collection] - def keys(self) -> List[str]: + def keys(self) -> List[_COLKEY]: """Return a sequence of string key names for all columns in this collection.""" return [k for (k, col) in self._collection] @@ -1345,7 +1362,7 @@ class ColumnCollection(Generic[_COL]): collection.""" return [col for (k, col) in self._collection] - def items(self) -> List[Tuple[str, _COL]]: + def items(self) -> List[Tuple[_COLKEY, _COL]]: """Return a sequence of (key, column) tuples for all columns in this collection each consisting of a string key name and a :class:`_sql.ColumnClause` or @@ -1389,7 +1406,7 @@ class ColumnCollection(Generic[_COL]): else: return True - def compare(self, other: ColumnCollection[Any]) -> bool: + def compare(self, other: ColumnCollection[Any, Any]) -> bool: """Compare this :class:`_expression.ColumnCollection` to another based on the names of the keys""" @@ -1444,7 +1461,7 @@ class ColumnCollection(Generic[_COL]): __hash__ = None # type: ignore def _populate_separate_keys( - self, iter_: Iterable[Tuple[str, _COL]] + self, iter_: Iterable[Tuple[_COLKEY, _COL]] ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) @@ -1455,7 +1472,7 @@ class ColumnCollection(Generic[_COL]): ) self._index.update({k: col for k, col in reversed(self._collection)}) - def add(self, column: _COL, key: Optional[str] = None) -> None: + def add(self, column: _COL, key: Optional[_COLKEY] = None) -> None: """Add a column to this :class:`_sql.ColumnCollection`. .. note:: @@ -1467,15 +1484,19 @@ class ColumnCollection(Generic[_COL]): object, use the :meth:`_schema.Table.append_column` method. """ + colkey: _COLKEY + if key is None: - key = column.key + colkey = column.key # type: ignore + else: + colkey = key l = len(self._collection) - self._collection.append((key, column)) + self._collection.append((colkey, column)) self._colset.add(column) self._index[l] = column - if key not in self._index: - self._index[key] = column + if colkey not in self._index: + self._index[colkey] = column def __getstate__(self) -> Dict[str, Any]: return {"_collection": self._collection, "_index": self._index} @@ -1499,11 +1520,11 @@ class ColumnCollection(Generic[_COL]): else: return True - def as_immutable(self) -> ImmutableColumnCollection[_COL]: - """Return an "immutable" form of this + def as_readonly(self) -> ReadOnlyColumnCollection[_COLKEY, _COL]: + """Return a "read only" form of this :class:`_sql.ColumnCollection`.""" - return ImmutableColumnCollection(self) + return ReadOnlyColumnCollection(self) def corresponding_column( self, column: _COL, require_embedded: bool = False @@ -1605,7 +1626,10 @@ class ColumnCollection(Generic[_COL]): return col -class DedupeColumnCollection(ColumnCollection[_COL]): +_NAMEDCOL = TypeVar("_NAMEDCOL", bound="NamedColumn[Any]") + + +class DedupeColumnCollection(ColumnCollection[str, _NAMEDCOL]): """A :class:`_expression.ColumnCollection` that maintains deduplicating behavior. @@ -1618,7 +1642,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]): """ - def add(self, column: _COL, key: Optional[str] = None) -> None: + def add(self, column: _NAMEDCOL, key: Optional[str] = None) -> None: if key is not None and column.key != key: raise exc.ArgumentError( @@ -1653,7 +1677,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]): self._index[key] = column def _populate_separate_keys( - self, iter_: Iterable[Tuple[str, _COL]] + self, iter_: Iterable[Tuple[str, _NAMEDCOL]] ) -> None: """populate from an iterator of (key, column)""" cols = list(iter_) @@ -1679,10 +1703,10 @@ class DedupeColumnCollection(ColumnCollection[_COL]): for col in replace_col: self.replace(col) - def extend(self, iter_: Iterable[_COL]) -> None: - self._populate_separate_keys((col.key, col) for col in iter_) + def extend(self, iter_: Iterable[_NAMEDCOL]) -> None: + self._populate_separate_keys((col.key, col) for col in iter_) # type: ignore # noqa: E501 - def remove(self, column: _COL) -> None: + def remove(self, column: _NAMEDCOL) -> None: if column not in self._colset: raise ValueError( "Can't remove column %r; column is not in this collection" @@ -1699,7 +1723,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]): # delete higher index del self._index[len(self._collection)] - def replace(self, column: _COL) -> None: + def replace(self, column: _NAMEDCOL) -> None: """add the given column to this collection, removing unaliased versions of this column as well as existing columns with the same key. @@ -1726,7 +1750,7 @@ class DedupeColumnCollection(ColumnCollection[_COL]): if column.key in self._index: remove_col.add(self._index[column.key]) - new_cols = [] + new_cols: List[Tuple[str, _NAMEDCOL]] = [] replaced = False for k, col in self._collection: if col in remove_col: @@ -1752,8 +1776,8 @@ class DedupeColumnCollection(ColumnCollection[_COL]): self._index.update(self._collection) -class ImmutableColumnCollection( - util.ImmutableContainer, ColumnCollection[_COL] +class ReadOnlyColumnCollection( + util.ReadOnlyContainer, ColumnCollection[_COLKEY, _COL] ): __slots__ = ("_parent",) @@ -1771,13 +1795,13 @@ class ImmutableColumnCollection( self.__init__(parent) # type: ignore def add(self, column: Any, key: Any = ...) -> Any: - self._immutable() + self._readonly() - def extend(self, elements: Any) -> None: - self._immutable() + def extend(self, elements: Any) -> NoReturn: + self._readonly() - def remove(self, item: Any) -> None: - self._immutable() + def remove(self, item: Any) -> NoReturn: + self._readonly() class ColumnSet(util.OrderedSet["ColumnClause[Any]"]): |
