summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/orm/interfaces.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/orm/interfaces.py')
-rw-r--r--lib/sqlalchemy/orm/interfaces.py423
1 files changed, 260 insertions, 163 deletions
diff --git a/lib/sqlalchemy/orm/interfaces.py b/lib/sqlalchemy/orm/interfaces.py
index d61ebe960..6c9fe7753 100644
--- a/lib/sqlalchemy/orm/interfaces.py
+++ b/lib/sqlalchemy/orm/interfaces.py
@@ -4,27 +4,45 @@
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
-"""Semi-private implementation objects which form the basis
-of ORM-mapped attributes, query options and mapper extension.
+"""
+
+Semi-private implementation objects which form the basis of ORM-mapped
+attributes, query options and mapper extension.
+
+Defines the [sqlalchemy.orm.interfaces#MapperExtension] class, which can be
+end-user subclassed to add event-based functionality to mappers. The
+remainder of this module is generally private to the ORM.
-Defines the [sqlalchemy.orm.interfaces#MapperExtension] class,
-which can be end-user subclassed to add event-based functionality
-to mappers. The remainder of this module is generally private to the
-ORM.
"""
from itertools import chain
-from sqlalchemy import exceptions, logging, util
+
+import sqlalchemy.exceptions as sa_exc
+from sqlalchemy import log, util
from sqlalchemy.sql import expression
-class_mapper = None
-__all__ = ['EXT_CONTINUE', 'EXT_STOP', 'EXT_PASS', 'MapperExtension',
- 'MapperProperty', 'PropComparator', 'StrategizedProperty',
- 'build_path', 'MapperOption',
- 'ExtensionOption', 'PropertyOption',
- 'AttributeExtension', 'StrategizedOption', 'LoaderStrategy' ]
+class_mapper = None
+collections = None
+
+__all__ = (
+ 'AttributeExtension',
+ 'EXT_CONTINUE',
+ 'EXT_STOP',
+ 'ExtensionOption',
+ 'InstrumentationManager',
+ 'LoaderStrategy',
+ 'MapperExtension',
+ 'MapperOption',
+ 'MapperProperty',
+ 'PropComparator',
+ 'PropertyOption',
+ 'SessionExtension',
+ 'StrategizedOption',
+ 'StrategizedProperty',
+ 'build_path',
+ )
-EXT_CONTINUE = EXT_PASS = util.symbol('EXT_CONTINUE')
+EXT_CONTINUE = util.symbol('EXT_CONTINUE')
EXT_STOP = util.symbol('EXT_STOP')
ONETOMANY = util.symbol('ONETOMANY')
@@ -44,10 +62,7 @@ class MapperExtension(object):
these exception cases, any return value other than EXT_CONTINUE or
EXT_STOP will be interpreted as equivalent to EXT_STOP.
- EXT_PASS is a synonym for EXT_CONTINUE and is provided for backward
- compatibility.
"""
-
def instrument_class(self, mapper, class_):
return EXT_CONTINUE
@@ -57,16 +72,6 @@ class MapperExtension(object):
def init_failed(self, mapper, class_, oldinit, instance, args, kwargs):
return EXT_CONTINUE
- def get_session(self):
- """Retrieve a contextual Session instance with which to
- register a new object.
-
- Note: this is not called if a session is provided with the
- `__init__` params (i.e. `_sa_session`).
- """
-
- return EXT_CONTINUE
-
def load(self, query, *args, **kwargs):
"""Override the `load` method of the Query object.
@@ -85,43 +90,6 @@ class MapperExtension(object):
return EXT_CONTINUE
- def get_by(self, query, *args, **kwargs):
- """Override the `get_by` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.get_by()`` if the value is anything other than
- EXT_CONTINUE.
-
- DEPRECATED.
- """
-
- return EXT_CONTINUE
-
- def select_by(self, query, *args, **kwargs):
- """Override the `select_by` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.select_by()`` if the value is anything other than
- EXT_CONTINUE.
-
- DEPRECATED.
- """
-
- return EXT_CONTINUE
-
- def select(self, query, *args, **kwargs):
- """Override the `select` method of the Query object.
-
- The return value of this method is used as the result of
- ``query.select()`` if the value is anything other than
- EXT_CONTINUE.
-
- DEPRECATED.
- """
-
- return EXT_CONTINUE
-
-
def translate_row(self, mapper, context, row):
"""Perform pre-processing on the given result row and return a
new row instance.
@@ -276,6 +244,56 @@ class MapperExtension(object):
return EXT_CONTINUE
+class SessionExtension(object):
+ """An extension hook object for Sessions. Subclasses may be installed into a Session
+ (or sessionmaker) using the ``extension`` keyword argument.
+ """
+
+ def before_commit(self, session):
+ """Execute right before commit is called.
+
+ Note that this may not be per-flush if a longer running transaction is ongoing."""
+
+ def after_commit(self, session):
+ """Execute after a commit has occured.
+
+ Note that this may not be per-flush if a longer running transaction is ongoing."""
+
+ def after_rollback(self, session):
+ """Execute after a rollback has occured.
+
+ Note that this may not be per-flush if a longer running transaction is ongoing."""
+
+ def before_flush(self, session, flush_context, instances):
+ """Execute before flush process has started.
+
+ `instances` is an optional list of objects which were passed to the ``flush()``
+ method.
+ """
+
+ def after_flush(self, session, flush_context):
+ """Execute after flush has completed, but before commit has been called.
+
+ Note that the session's state is still in pre-flush, i.e. 'new', 'dirty',
+ and 'deleted' lists still show pre-flush state as well as the history
+ settings on instance attributes."""
+
+ def after_flush_postexec(self, session, flush_context):
+ """Execute after flush has completed, and after the post-exec state occurs.
+
+ This will be when the 'new', 'dirty', and 'deleted' lists are in their final
+ state. An actual commit() may or may not have occured, depending on whether or not
+ the flush started its own transaction or participated in a larger transaction.
+ """
+
+ def after_begin(self, session, transaction, connection):
+ """Execute after a transaction is begun on a connection
+
+ `transaction` is the SessionTransaction. This method is called after an
+ engine level transaction is begun on a connection.
+ """
+
+
class MapperProperty(object):
"""Manage the relationship of a ``Mapper`` to a single class
attribute, as well as that attribute as it appears on individual
@@ -283,7 +301,7 @@ class MapperProperty(object):
attribute access, loading behavior, and dependency calculations.
"""
- def setup(self, querycontext, **kwargs):
+ def setup(self, context, entity, path, adapter, **kwargs):
"""Called by Query for the purposes of constructing a SQL statement.
Each MapperProperty associated with the target mapper processes the
@@ -293,8 +311,8 @@ class MapperProperty(object):
pass
- def create_row_processor(self, selectcontext, mapper, row):
- """Return a 3-tuple consiting of two row processing functions and an instance post-processing function.
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
+ """Return a 2-tuple consiting of two row processing functions and an instance post-processing function.
Input arguments are the query.SelectionContext and the *first*
applicable row of a result set obtained within
@@ -305,32 +323,24 @@ class MapperProperty(object):
columns present in the row (which will be the same columns present in
all rows) are used to determine the presence and behavior of the
returned callables. The callables will then be used to process all
- rows and to post-process all instances, respectively.
+ rows and instances.
Callables are of the following form::
- def new_execute(instance, row, **flags):
- # process incoming instance and given row. the instance is
+ def new_execute(state, row, **flags):
+ # process incoming instance state and given row. the instance is
# "new" and was just created upon receipt of this row.
# flags is a dictionary containing at least the following
# attributes:
# isnew - indicates if the instance was newly created as a
# result of reading this row
# instancekey - identity key of the instance
- # optional attribute:
- # ispostselect - indicates if this row resulted from a
- # 'post' select of additional tables/columns
- def existing_execute(instance, row, **flags):
- # process incoming instance and given row. the instance is
+ def existing_execute(state, row, **flags):
+ # process incoming instance state and given row. the instance is
# "existing" and was created based on a previous row.
- def post_execute(instance, **flags):
- # process instance after all result rows have been processed.
- # this function should be used to issue additional selections
- # in order to eagerly load additional properties.
-
- return (new_execute, existing_execute, post_execute)
+ return (new_execute, existing_execute)
Either of the three tuples can be ``None`` in which case no function
is called.
@@ -347,20 +357,6 @@ class MapperProperty(object):
return iter([])
- def get_criterion(self, query, key, value):
- """Return a ``WHERE`` clause suitable for this
- ``MapperProperty`` corresponding to the given key/value pair,
- where the key is a column or object property name, and value
- is a value to be matched. This is only picked up by
- ``PropertyLoaders``.
-
- This is called by a ``Query``'s ``join_by`` method to formulate a set
- of key/value pairs into a ``WHERE`` criterion that spans multiple
- tables if needed.
- """
-
- return None
-
def set_parent(self, parent):
self.parent = parent
@@ -427,10 +423,10 @@ class PropComparator(expression.ColumnOperators):
which returns the MapperProperty associated with this
PropComparator.
"""
-
- def expression_element(self):
- return self.clause_element()
-
+
+ def __clause_element__(self):
+ raise NotImplementedError("%r" % self)
+
def contains_op(a, b):
return a.contains(b)
contains_op = staticmethod(contains_op)
@@ -511,37 +507,44 @@ class StrategizedProperty(MapperProperty):
``StrategizedOption`` objects via the Query.options() method.
"""
- def _get_context_strategy(self, context):
- path = context.path
- return self._get_strategy(context.attributes.get(("loaderstrategy", path), self.strategy.__class__))
-
+ def __get_context_strategy(self, context, path):
+ cls = context.attributes.get(("loaderstrategy", path), None)
+ if cls:
+ try:
+ return self.__all_strategies[cls]
+ except KeyError:
+ return self.__init_strategy(cls)
+ else:
+ return self.strategy
+
def _get_strategy(self, cls):
try:
- return self._all_strategies[cls]
+ return self.__all_strategies[cls]
except KeyError:
- # cache the located strategy per class for faster re-lookup
- strategy = cls(self)
- strategy.init()
- self._all_strategies[cls] = strategy
- return strategy
-
- def setup(self, querycontext, **kwargs):
- self._get_context_strategy(querycontext).setup_query(querycontext, **kwargs)
+ return self.__init_strategy(cls)
+
+ def __init_strategy(self, cls):
+ self.__all_strategies[cls] = strategy = cls(self)
+ strategy.init()
+ return strategy
+
+ def setup(self, context, entity, path, adapter, **kwargs):
+ self.__get_context_strategy(context, path + (self.key,)).setup_query(context, entity, path, adapter, **kwargs)
- def create_row_processor(self, selectcontext, mapper, row):
- return self._get_context_strategy(selectcontext).create_row_processor(selectcontext, mapper, row)
+ def create_row_processor(self, context, path, mapper, row, adapter):
+ return self.__get_context_strategy(context, path + (self.key,)).create_row_processor(context, path, mapper, row, adapter)
def do_init(self):
- self._all_strategies = {}
- self.strategy = self._get_strategy(self.strategy_class)
+ self.__all_strategies = {}
+ self.strategy = self.__init_strategy(self.strategy_class)
if self.is_primary():
self.strategy.init_class_attribute()
-def build_path(mapper, key, prev=None):
+def build_path(entity, key, prev=None):
if prev:
- return prev + (mapper.base_mapper, key)
+ return prev + (entity, key)
else:
- return (mapper.base_mapper, key)
+ return (entity, key)
def serialize_path(path):
if path is None:
@@ -585,9 +588,9 @@ class ExtensionOption(MapperOption):
self.ext = ext
def process_query(self, query):
- query._extension = query._extension.copy()
- query._extension.insert(self.ext)
-
+ entity = query._generate_mapper_zero()
+ entity.extension = entity.extension.copy()
+ entity.extension.push(self.ext)
class PropertyOption(MapperOption):
"""A MapperOption that is applied to a property off the mapper or
@@ -607,60 +610,86 @@ class PropertyOption(MapperOption):
def _process(self, query, raiseerr):
if self._should_log_debug:
self.logger.debug("applying option to Query, property key '%s'" % self.key)
- paths = self._get_paths(query, raiseerr)
+ paths = self.__get_paths(query, raiseerr)
if paths:
self.process_query_property(query, paths)
def process_query_property(self, query, paths):
pass
+
+ def __find_entity(self, query, mapper, raiseerr):
+ from sqlalchemy.orm.util import _class_to_mapper, _is_aliased_class
+
+ if _is_aliased_class(mapper):
+ searchfor = mapper
+ else:
+ searchfor = _class_to_mapper(mapper).base_mapper
- def _get_paths(self, query, raiseerr):
+ for ent in query._mapper_entities:
+ if ent.path_entity is searchfor:
+ return ent
+ else:
+ if raiseerr:
+ raise sa_exc.ArgumentError("Can't find entity %s in Query. Current list: %r" % (searchfor, [str(m.path_entity) for m in query._entities]))
+ else:
+ return None
+
+ def __get_paths(self, query, raiseerr):
path = None
+ entity = None
l = []
+
current_path = list(query._current_path)
-
+
if self.mapper:
- global class_mapper
- if class_mapper is None:
- from sqlalchemy.orm import class_mapper
- mapper = self.mapper
- if isinstance(self.mapper, type):
- mapper = class_mapper(mapper)
- if mapper is not query.mapper and mapper not in [q.mapper for q in query._entities]:
- raise exceptions.ArgumentError("Can't find entity %s in Query. Current list: %r" % (str(mapper), [str(m) for m in query._entities]))
- else:
- mapper = query.mapper
- if isinstance(self.key, basestring):
- tokens = self.key.split('.')
- else:
- tokens = util.to_list(self.key)
+ entity = self.__find_entity(query, self.mapper, raiseerr)
+ mapper = entity.mapper
+ path_element = entity.path_entity
- for token in tokens:
- if isinstance(token, basestring):
- prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
- elif isinstance(token, PropComparator):
- prop = token.property
- token = prop.key
-
+ for key in util.to_list(self.key):
+ if isinstance(key, basestring):
+ tokens = key.split('.')
else:
- raise exceptions.ArgumentError("mapper option expects string key or list of attributes")
-
- if current_path and token == current_path[1]:
- current_path = current_path[2:]
- continue
+ tokens = [key]
+ for token in tokens:
+ if isinstance(token, basestring):
+ if not entity:
+ entity = query._entity_zero()
+ path_element = entity.path_entity
+ mapper = entity.mapper
+ prop = mapper.get_property(token, resolve_synonyms=True, raiseerr=raiseerr)
+ key = token
+ elif isinstance(token, PropComparator):
+ prop = token.property
+ if not entity:
+ entity = self.__find_entity(query, token.parententity, raiseerr)
+ if not entity:
+ return []
+ path_element = entity.path_entity
+ key = prop.key
+ else:
+ raise sa_exc.ArgumentError("mapper option expects string key or list of attributes")
+
+ if current_path and key == current_path[1]:
+ current_path = current_path[2:]
+ continue
- if prop is None:
- return []
- path = build_path(mapper, prop.key, path)
- l.append(path)
- if getattr(token, '_of_type', None):
- mapper = token._of_type
- else:
- mapper = getattr(prop, 'mapper', None)
+ if prop is None:
+ return []
+
+ path = build_path(path_element, prop.key, path)
+ l.append(path)
+ if getattr(token, '_of_type', None):
+ path_element = mapper = token._of_type
+ else:
+ path_element = mapper = getattr(prop, 'mapper', None)
+ if path_element:
+ path_element = path_element.base_mapper
+
return l
-PropertyOption.logger = logging.class_logger(PropertyOption)
-PropertyOption._should_log_debug = logging.is_debug_enabled(PropertyOption.logger)
+PropertyOption.logger = log.class_logger(PropertyOption)
+PropertyOption._should_log_debug = log.is_debug_enabled(PropertyOption.logger)
class AttributeExtension(object):
"""An abstract class which specifies `append`, `delete`, and `set`
@@ -732,10 +761,10 @@ class LoaderStrategy(object):
def init_class_attribute(self):
pass
- def setup_query(self, context, **kwargs):
+ def setup_query(self, context, entity, path, adapter, **kwargs):
pass
- def create_row_processor(self, selectcontext, mapper, row):
+ def create_row_processor(self, selectcontext, path, mapper, row, adapter):
"""Return row processing functions which fulfill the contract specified
by MapperProperty.create_row_processor.
@@ -744,3 +773,71 @@ class LoaderStrategy(object):
"""
raise NotImplementedError()
+
+ def __str__(self):
+ return str(self.parent_property)
+
+ def debug_callable(self, fn, logger, announcement, logfn):
+ if announcement:
+ logger.debug(announcement)
+ if logfn:
+ def call(*args, **kwargs):
+ logger.debug(logfn(*args, **kwargs))
+ return fn(*args, **kwargs)
+ return call
+ else:
+ return fn
+
+class InstrumentationManager(object):
+ """User-defined class instrumentation extension."""
+
+ # r4361 added a mandatory (cls) constructor to this interface.
+ # given that, perhaps class_ should be dropped from all of these
+ # signatures.
+
+ def __init__(self, class_):
+ pass
+
+ def manage(self, class_, manager):
+ setattr(class_, '_default_class_manager', manager)
+
+ def dispose(self, class_, manager):
+ delattr(class_, '_default_class_manager')
+
+ def manager_getter(self, class_):
+ def get(cls):
+ return cls._default_class_manager
+ return get
+
+ def instrument_attribute(self, class_, key, inst):
+ pass
+
+ def install_descriptor(self, class_, key, inst):
+ setattr(class_, key, inst)
+
+ def uninstall_descriptor(self, class_, key):
+ delattr(class_, key)
+
+ def install_member(self, class_, key, implementation):
+ setattr(class_, key, implementation)
+
+ def uninstall_member(self, class_, key):
+ delattr(class_, key)
+
+ def instrument_collection_class(self, class_, key, collection_class):
+ global collections
+ if collections is None:
+ from sqlalchemy.orm import collections
+ return collections.prepare_instrumentation(collection_class)
+
+ def get_instance_dict(self, class_, instance):
+ return instance.__dict__
+
+ def initialize_instance_dict(self, class_, instance):
+ pass
+
+ def install_state(self, class_, instance, state):
+ setattr(instance, '_default_state', state)
+
+ def state_getter(self, class_):
+ return lambda instance: getattr(instance, '_default_state')