diff options
| author | J. Nick Koston <nick@koston.org> | 2023-04-19 18:39:18 -0400 |
|---|---|---|
| committer | Federico Caselli <cfederico87@gmail.com> | 2023-04-26 20:19:17 +0200 |
| commit | ff198e35f0e04b8d38df25df234e72259069b4d1 (patch) | |
| tree | c48db9a0366b48c8caaa35ad9ab83a354aaa7d32 /lib/sqlalchemy/engine | |
| parent | 9f675fd042b05977f1b38887c2fbbb54ecd424f7 (diff) | |
| download | sqlalchemy-ff198e35f0e04b8d38df25df234e72259069b4d1.tar.gz | |
Prebuild the row string-to-position lookup for Rows
Improved :class:`_engine.Row` implementation to optimize
``__getattr__`` performance.
The pickle format of :class:`_engine.Row` has changed as part of this commit.
Pickles saved by older SQLAlchemy versions can still be loaded by this version,
but pickles saved by this version cannot be loaded by older versions.
Fixes: #9678
Closes: #9668
Pull-request: https://github.com/sqlalchemy/sqlalchemy/pull/9668
Pull-request-sha: 86b8ccd1959dbd91b1208f7a648a91f217e1f866
Change-Id: Ia85c26a59e1a57ba2bf0d65578c6168f82a559f2
Diffstat (limited to 'lib/sqlalchemy/engine')
| -rw-r--r-- | lib/sqlalchemy/engine/_py_row.py | 72 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/cursor.py | 86 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/result.py | 55 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/row.py | 28 |
4 files changed, 103 insertions, 138 deletions
diff --git a/lib/sqlalchemy/engine/_py_row.py b/lib/sqlalchemy/engine/_py_row.py index 1b952fe4c..4a9acec9b 100644 --- a/lib/sqlalchemy/engine/_py_row.py +++ b/lib/sqlalchemy/engine/_py_row.py @@ -1,6 +1,5 @@ from __future__ import annotations -import enum import operator import typing from typing import Any @@ -8,13 +7,12 @@ from typing import Callable from typing import Dict from typing import Iterator from typing import List +from typing import Mapping from typing import Optional from typing import Tuple from typing import Type -from typing import Union if typing.TYPE_CHECKING: - from .result import _KeyMapType from .result import _KeyType from .result import _ProcessorsType from .result import _RawRowType @@ -24,38 +22,25 @@ if typing.TYPE_CHECKING: MD_INDEX = 0 # integer index in cursor.description -class _KeyStyle(enum.IntEnum): - KEY_INTEGER_ONLY = 0 - """__getitem__ only allows integer values and slices, raises TypeError - otherwise""" - - KEY_OBJECTS_ONLY = 1 - """__getitem__ only allows string/object values, raises TypeError - otherwise""" - - -KEY_INTEGER_ONLY, KEY_OBJECTS_ONLY = list(_KeyStyle) - - class BaseRow: - __slots__ = ("_parent", "_data", "_keymap", "_key_style") + __slots__ = ("_parent", "_data", "_key_to_index") _parent: ResultMetaData + _key_to_index: Mapping[_KeyType, int] _data: _RawRowType - _keymap: _KeyMapType - _key_style: _KeyStyle def __init__( self, parent: ResultMetaData, processors: Optional[_ProcessorsType], - keymap: _KeyMapType, - key_style: _KeyStyle, + key_to_index: Mapping[_KeyType, int], data: _RawRowType, ): """Row objects are constructed by CursorResult objects.""" object.__setattr__(self, "_parent", parent) + object.__setattr__(self, "_key_to_index", key_to_index) + if processors: object.__setattr__( self, @@ -70,10 +55,6 @@ class BaseRow: else: object.__setattr__(self, "_data", tuple(data)) - object.__setattr__(self, "_keymap", keymap) - - object.__setattr__(self, "_key_style", key_style) - def __reduce__(self) -> 
Tuple[Callable[..., BaseRow], Tuple[Any, ...]]: return ( rowproxy_reconstructor, @@ -81,18 +62,13 @@ class BaseRow: ) def __getstate__(self) -> Dict[str, Any]: - return { - "_parent": self._parent, - "_data": self._data, - "_key_style": self._key_style, - } + return {"_parent": self._parent, "_data": self._data} def __setstate__(self, state: Dict[str, Any]) -> None: parent = state["_parent"] object.__setattr__(self, "_parent", parent) object.__setattr__(self, "_data", state["_data"]) - object.__setattr__(self, "_keymap", parent._keymap) - object.__setattr__(self, "_key_style", state["_key_style"]) + object.__setattr__(self, "_key_to_index", parent._key_to_index) def _values_impl(self) -> List[Any]: return list(self) @@ -106,34 +82,22 @@ class BaseRow: def __hash__(self) -> int: return hash(self._data) - def _get_by_int_impl(self, key: Union[int, slice]) -> Any: + def __getitem__(self, key: Any) -> Any: return self._data[key] - if not typing.TYPE_CHECKING: - __getitem__ = _get_by_int_impl - - def _get_by_key_impl_mapping(self, key: _KeyType) -> Any: + def _get_by_key_impl_mapping(self, key: str) -> Any: try: - rec = self._keymap[key] - except KeyError as ke: - rec = self._parent._key_fallback(key, ke) - - mdindex = rec[MD_INDEX] - if mdindex is None: - self._parent._raise_for_ambiguous_column_name(rec) - # NOTE: keep "== KEY_OBJECTS_ONLY" instead of "is KEY_OBJECTS_ONLY" - # since deserializing the class from cython will load an int in - # _key_style, not an instance of _KeyStyle - elif self._key_style == KEY_OBJECTS_ONLY and isinstance(key, int): - raise KeyError(key) - - return self._data[mdindex] + return self._data[self._key_to_index[key]] + except KeyError: + pass + self._parent._key_not_found(key, False) def __getattr__(self, name: str) -> Any: try: - return self._get_by_key_impl_mapping(name) - except KeyError as e: - raise AttributeError(e.args[0]) from e + return self._data[self._key_to_index[name]] + except KeyError: + pass + 
self._parent._key_not_found(name, True) # This reconstructor is necessary so that pickles with the Cy extension or diff --git a/lib/sqlalchemy/engine/cursor.py b/lib/sqlalchemy/engine/cursor.py index aaf2c1918..bd46f30ac 100644 --- a/lib/sqlalchemy/engine/cursor.py +++ b/lib/sqlalchemy/engine/cursor.py @@ -21,9 +21,9 @@ from typing import ClassVar from typing import Dict from typing import Iterator from typing import List +from typing import Mapping from typing import NoReturn from typing import Optional -from typing import overload from typing import Sequence from typing import Tuple from typing import TYPE_CHECKING @@ -123,7 +123,7 @@ _CursorKeyMapRecType = Tuple[ Optional[str], # MD_UNTRANSLATED ] -_CursorKeyMapType = Dict["_KeyType", _CursorKeyMapRecType] +_CursorKeyMapType = Mapping["_KeyType", _CursorKeyMapRecType] # same as _CursorKeyMapRecType except the MD_INDEX value is definitely # not None @@ -149,7 +149,8 @@ class CursorResultMetaData(ResultMetaData): "_tuplefilter", "_translated_indexes", "_safe_for_cache", - "_unpickled" + "_unpickled", + "_key_to_index" # don't need _unique_filters support here for now. Can be added # if a need arises. 
) @@ -193,6 +194,7 @@ class CursorResultMetaData(ResultMetaData): new_obj._translated_indexes = translated_indexes new_obj._safe_for_cache = safe_for_cache new_obj._keymap_by_result_column_idx = keymap_by_result_column_idx + new_obj._key_to_index = self._make_key_to_index(keymap, MD_INDEX) return new_obj def _remove_processors(self) -> CursorResultMetaData: @@ -217,7 +219,7 @@ class CursorResultMetaData(ResultMetaData): assert not self._tuplefilter - keymap = self._keymap.copy() + keymap = dict(self._keymap) offset = len(self._keys) keymap.update( { @@ -232,7 +234,6 @@ class CursorResultMetaData(ResultMetaData): for key, value in other._keymap.items() } ) - return self._make_new_metadata( unpickled=self._unpickled, processors=self._processors + other._processors, # type: ignore @@ -258,7 +259,7 @@ class CursorResultMetaData(ResultMetaData): tup = tuplegetter(*indexes) new_recs = [(index,) + rec[1:] for index, rec in enumerate(recs)] - keymap: _KeyMapType = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} + keymap = {rec[MD_LOOKUP_KEY]: rec for rec in new_recs} # TODO: need unit test for: # result = connection.execute("raw sql, no columns").scalars() # without the "or ()" it's failing because MD_OBJECTS is None @@ -274,7 +275,7 @@ class CursorResultMetaData(ResultMetaData): keys=new_keys, tuplefilter=tup, translated_indexes=indexes, - keymap=keymap, + keymap=keymap, # type: ignore[arg-type] safe_for_cache=self._safe_for_cache, keymap_by_result_column_idx=self._keymap_by_result_column_idx, ) @@ -491,6 +492,8 @@ class CursorResultMetaData(ResultMetaData): } ) + self._key_to_index = self._make_key_to_index(self._keymap, MD_INDEX) + def _merge_cursor_description( self, context, @@ -807,41 +810,25 @@ class CursorResultMetaData(ResultMetaData): untranslated, ) - @overload - def _key_fallback( - self, key: Any, err: Exception, raiseerr: Literal[True] = ... - ) -> NoReturn: - ... 
+ if not TYPE_CHECKING: - @overload - def _key_fallback( - self, key: Any, err: Exception, raiseerr: Literal[False] = ... - ) -> None: - ... - - @overload - def _key_fallback( - self, key: Any, err: Exception, raiseerr: bool = ... - ) -> Optional[NoReturn]: - ... - - def _key_fallback( - self, key: Any, err: Exception, raiseerr: bool = True - ) -> Optional[NoReturn]: - - if raiseerr: - if self._unpickled and isinstance(key, elements.ColumnElement): - raise exc.NoSuchColumnError( - "Row was unpickled; lookup by ColumnElement " - "is unsupported" - ) from err + def _key_fallback( + self, key: Any, err: Optional[Exception], raiseerr: bool = True + ) -> Optional[NoReturn]: + + if raiseerr: + if self._unpickled and isinstance(key, elements.ColumnElement): + raise exc.NoSuchColumnError( + "Row was unpickled; lookup by ColumnElement " + "is unsupported" + ) from err + else: + raise exc.NoSuchColumnError( + "Could not locate column in row for column '%s'" + % util.string_or_unprintable(key) + ) from err else: - raise exc.NoSuchColumnError( - "Could not locate column in row for column '%s'" - % util.string_or_unprintable(key) - ) from err - else: - return None + return None def _raise_for_ambiguous_column_name(self, rec): raise exc.InvalidRequestError( @@ -919,8 +906,8 @@ class CursorResultMetaData(ResultMetaData): def __setstate__(self, state): self._processors = [None for _ in range(len(state["_keys"]))] self._keymap = state["_keymap"] - self._keymap_by_result_column_idx = None + self._key_to_index = self._make_key_to_index(self._keymap, MD_INDEX) self._keys = state["_keys"] self._unpickled = True if state["_translated_indexes"]: @@ -1371,6 +1358,14 @@ class _NoResultMetaData(ResultMetaData): self._we_dont_return_rows() @property + def _key_to_index(self): + self._we_dont_return_rows() + + @property + def _processors(self): + self._we_dont_return_rows() + + @property def keys(self): self._we_dont_return_rows() @@ -1458,12 +1453,11 @@ class CursorResult(Result[_T]): 
metadata = self._init_metadata(context, cursor_description) - keymap = metadata._keymap - processors = metadata._processors - process_row = Row - key_style = process_row._default_key_style _make_row = functools.partial( - process_row, metadata, processors, keymap, key_style + Row, + metadata, + metadata._processors, + metadata._key_to_index, ) if log_row: diff --git a/lib/sqlalchemy/engine/result.py b/lib/sqlalchemy/engine/result.py index d5b8057ef..cc6d26c88 100644 --- a/lib/sqlalchemy/engine/result.py +++ b/lib/sqlalchemy/engine/result.py @@ -22,6 +22,7 @@ from typing import Generic from typing import Iterable from typing import Iterator from typing import List +from typing import Mapping from typing import NoReturn from typing import Optional from typing import overload @@ -59,7 +60,7 @@ _KeyIndexType = Union[str, "Column[Any]", int] # is overridden in cursor using _CursorKeyMapRecType _KeyMapRecType = Any -_KeyMapType = Dict[_KeyType, _KeyMapRecType] +_KeyMapType = Mapping[_KeyType, _KeyMapRecType] _RowData = Union[Row, RowMapping, Any] @@ -99,6 +100,7 @@ class ResultMetaData: _keymap: _KeyMapType _keys: Sequence[str] _processors: Optional[_ProcessorsType] + _key_to_index: Mapping[_KeyType, int] @property def keys(self) -> RMKeyView: @@ -112,24 +114,27 @@ class ResultMetaData: @overload def _key_fallback( - self, key: Any, err: Exception, raiseerr: Literal[True] = ... + self, key: Any, err: Optional[Exception], raiseerr: Literal[True] = ... ) -> NoReturn: ... @overload def _key_fallback( - self, key: Any, err: Exception, raiseerr: Literal[False] = ... + self, + key: Any, + err: Optional[Exception], + raiseerr: Literal[False] = ..., ) -> None: ... @overload def _key_fallback( - self, key: Any, err: Exception, raiseerr: bool = ... + self, key: Any, err: Optional[Exception], raiseerr: bool = ... ) -> Optional[NoReturn]: ... 
def _key_fallback( - self, key: Any, err: Exception, raiseerr: bool = True + self, key: Any, err: Optional[Exception], raiseerr: bool = True ) -> Optional[NoReturn]: assert raiseerr raise KeyError(key) from err @@ -177,6 +182,29 @@ class ResultMetaData: indexes = self._indexes_for_keys(keys) return tuplegetter(*indexes) + def _make_key_to_index( + self, keymap: Mapping[_KeyType, Sequence[Any]], index: int + ) -> Mapping[_KeyType, int]: + return { + key: rec[index] + for key, rec in keymap.items() + if rec[index] is not None + } + + def _key_not_found(self, key: Any, attr_error: bool) -> NoReturn: + if key in self._keymap: + # the index must be none in this case + self._raise_for_ambiguous_column_name(self._keymap[key]) + else: + # unknown key + if attr_error: + try: + self._key_fallback(key, None) + except KeyError as ke: + raise AttributeError(ke.args[0]) from ke + else: + self._key_fallback(key, None) + class RMKeyView(typing.KeysView[Any]): __slots__ = ("_parent", "_keys") @@ -222,6 +250,7 @@ class SimpleResultMetaData(ResultMetaData): "_tuplefilter", "_translated_indexes", "_unique_filters", + "_key_to_index", ) _keys: Sequence[str] @@ -257,6 +286,8 @@ class SimpleResultMetaData(ResultMetaData): self._processors = _processors + self._key_to_index = self._make_key_to_index(self._keymap, 0) + def _has_key(self, key: object) -> bool: return key in self._keymap @@ -359,7 +390,7 @@ def result_tuple( ) -> Callable[[Iterable[Any]], Row[Any]]: parent = SimpleResultMetaData(fields, extra) return functools.partial( - Row, parent, parent._processors, parent._keymap, Row._default_key_style + Row, parent, parent._processors, parent._key_to_index ) @@ -424,21 +455,19 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): def process_row( # type: ignore metadata: ResultMetaData, processors: _ProcessorsType, - keymap: _KeyMapType, - key_style: Any, + key_to_index: Mapping[_KeyType, int], scalar_obj: Any, ) -> Row[Any]: return _proc( - metadata, processors, keymap, key_style, 
(scalar_obj,) + metadata, processors, key_to_index, (scalar_obj,) ) else: process_row = Row # type: ignore - key_style = Row._default_key_style metadata = self._metadata - keymap = metadata._keymap + key_to_index = metadata._key_to_index processors = metadata._processors tf = metadata._tuplefilter @@ -447,7 +476,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): processors = tf(processors) _make_row_orig: Callable[..., _R] = functools.partial( # type: ignore # noqa E501 - process_row, metadata, processors, keymap, key_style + process_row, metadata, processors, key_to_index ) fixed_tf = tf @@ -457,7 +486,7 @@ class ResultInternal(InPlaceGenerative, Generic[_R]): else: make_row = functools.partial( # type: ignore - process_row, metadata, processors, keymap, key_style + process_row, metadata, processors, key_to_index ) fns: Tuple[Any, ...] = () diff --git a/lib/sqlalchemy/engine/row.py b/lib/sqlalchemy/engine/row.py index e15ea7b17..4b767da09 100644 --- a/lib/sqlalchemy/engine/row.py +++ b/lib/sqlalchemy/engine/row.py @@ -34,12 +34,8 @@ from ..util._has_cy import HAS_CYEXTENSION if TYPE_CHECKING or not HAS_CYEXTENSION: from ._py_row import BaseRow as BaseRow - from ._py_row import KEY_INTEGER_ONLY - from ._py_row import KEY_OBJECTS_ONLY else: from sqlalchemy.cyextension.resultproxy import BaseRow as BaseRow - from sqlalchemy.cyextension.resultproxy import KEY_INTEGER_ONLY - from sqlalchemy.cyextension.resultproxy import KEY_OBJECTS_ONLY if TYPE_CHECKING: from .result import _KeyType @@ -80,8 +76,6 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): __slots__ = () - _default_key_style = KEY_INTEGER_ONLY - def __setattr__(self, name: str, value: Any) -> NoReturn: raise AttributeError("can't set attribute") @@ -134,24 +128,12 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): .. 
versionadded:: 1.4 """ - return RowMapping( - self._parent, - None, - self._keymap, - RowMapping._default_key_style, - self._data, - ) + return RowMapping(self._parent, None, self._key_to_index, self._data) def _filter_on_values( self, filters: Optional[Sequence[Optional[_ResultProcessorType[Any]]]] ) -> Row[Any]: - return Row( - self._parent, - filters, - self._keymap, - self._key_style, - self._data, - ) + return Row(self._parent, filters, self._key_to_index, self._data) if not TYPE_CHECKING: @@ -198,9 +180,7 @@ class Row(BaseRow, Sequence[Any], Generic[_TP]): def __getitem__(self, index: slice) -> Sequence[Any]: ... - def __getitem__( - self, index: Union[int, slice] - ) -> Union[Any, Sequence[Any]]: + def __getitem__(self, index: Union[int, slice]) -> Any: ... def __lt__(self, other: Any) -> bool: @@ -337,8 +317,6 @@ class RowMapping(BaseRow, typing.Mapping["_KeyType", Any]): __slots__ = () - _default_key_style = KEY_OBJECTS_ONLY - if TYPE_CHECKING: def __getitem__(self, key: _KeyType) -> Any: |
