summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2014-10-14 14:04:17 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2014-10-14 14:04:17 -0400
commit6de1a878702b8737b2257b89b478ead79a8d78cc (patch)
tree6d96365a5eeb6575c896c9df8cd4ec46f85f2e26
parent1e5ffa029a2b9adc1a6a3c83fb7e88a99d4e6448 (diff)
downloadsqlalchemy-6de1a878702b8737b2257b89b478ead79a8d78cc.tar.gz
- Improvements to the mechanism used by :class:`.Session` to locate
"binds" (e.g. engines to use), such engines can be associated with mixin classes, concrete subclasses, as well as a wider variety of table metadata such as joined inheritance tables. fixes #3035
-rw-r--r--doc/build/changelog/changelog_10.rst13
-rw-r--r--doc/build/changelog/migration_10.rst42
-rw-r--r--lib/sqlalchemy/orm/session.py96
-rw-r--r--test/orm/test_bind.py220
-rw-r--r--test/orm/test_session.py15
5 files changed, 336 insertions, 50 deletions
diff --git a/doc/build/changelog/changelog_10.rst b/doc/build/changelog/changelog_10.rst
index 8578c7883..66fa2ad26 100644
--- a/doc/build/changelog/changelog_10.rst
+++ b/doc/build/changelog/changelog_10.rst
@@ -22,6 +22,19 @@
on compatibility concerns, see :doc:`/changelog/migration_10`.
.. change::
+ :tags: bug, orm
+ :tickets: 3035
+
+ Improvements to the mechanism used by :class:`.Session` to locate
+ "binds" (e.g. engines to use), such engines can be associated with
+ mixin classes, concrete subclasses, as well as a wider variety
+ of table metadata such as joined inheritance tables.
+
+ .. seealso::
+
+ :ref:`bug_3035`
+
+ .. change::
:tags: bug, general
:tickets: 3218
diff --git a/doc/build/changelog/migration_10.rst b/doc/build/changelog/migration_10.rst
index 951e39603..dd8964f8b 100644
--- a/doc/build/changelog/migration_10.rst
+++ b/doc/build/changelog/migration_10.rst
@@ -468,6 +468,48 @@ object totally smokes both namedtuple and KeyedTuple::
:ticket:`3176`
+.. _bug_3035:
+
+Session.get_bind() handles a wider variety of inheritance scenarios
+-------------------------------------------------------------------
+
+The :meth:`.Session.get_bind` method is invoked whenever a query or unit
+of work flush process seeks to locate the database engine that corresponds
+to a particular class. The method has been improved to handle a variety
+of inheritance-oriented scenarios, including:
+
+* Binding to a Mixin or Abstract Class::
+
+ class MyClass(SomeMixin, Base):
+ __tablename__ = 'my_table'
+ # ...
+
+ session = Session(binds={SomeMixin: some_engine})
+
+
+* Binding to inherited concrete subclasses individually based on table::
+
+ class BaseClass(Base):
+ __tablename__ = 'base'
+
+ # ...
+
+ class ConcreteSubClass(BaseClass):
+ __tablename__ = 'concrete'
+
+ # ...
+
+ __mapper_args__ = {'concrete': True}
+
+
+ session = Session(binds={
+ base_table: some_engine,
+ concrete_table: some_other_engine
+ })
+
+
+:ticket:`3035`
+
.. _feature_3178:
New systems to safely emit parameterized warnings
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 13afcb357..db9d3a51d 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -641,14 +641,8 @@ class Session(_SessionClassMethods):
SessionExtension._adapt_listener(self, ext)
if binds is not None:
- for mapperortable, bind in binds.items():
- insp = inspect(mapperortable)
- if insp.is_selectable:
- self.bind_table(mapperortable, bind)
- elif insp.is_mapper:
- self.bind_mapper(mapperortable, bind)
- else:
- assert False
+ for key, bind in binds.items():
+ self._add_bind(key, bind)
if not self.autocommit:
self.begin()
@@ -1026,40 +1020,47 @@ class Session(_SessionClassMethods):
# TODO: + crystallize + document resolution order
# vis. bind_mapper/bind_table
- def bind_mapper(self, mapper, bind):
- """Bind operations for a mapper to a Connectable.
-
- mapper
- A mapper instance or mapped class
+ def _add_bind(self, key, bind):
+ try:
+ insp = inspect(key)
+ except sa_exc.NoInspectionAvailable:
+ if not isinstance(key, type):
+ raise exc.ArgumentError(
+ "Not acceptable bind target: %s" %
+ key)
+ else:
+ self.__binds[key] = bind
+ else:
+ if insp.is_selectable:
+ self.__binds[insp] = bind
+ elif insp.is_mapper:
+ self.__binds[insp.class_] = bind
+ for selectable in insp._all_tables:
+ self.__binds[selectable] = bind
+ else:
+ raise exc.ArgumentError(
+ "Not acceptable bind target: %s" %
+ key)
- bind
- Any Connectable: a :class:`.Engine` or :class:`.Connection`.
+ def bind_mapper(self, mapper, bind):
+ """Associate a :class:`.Mapper` with a "bind", e.g. a :class:`.Engine`
+ or :class:`.Connection`.
- All subsequent operations involving this mapper will use the given
- `bind`.
+ The given mapper is added to a lookup used by the
+ :meth:`.Session.get_bind` method.
"""
- if isinstance(mapper, type):
- mapper = class_mapper(mapper)
-
- self.__binds[mapper.base_mapper] = bind
- for t in mapper._all_tables:
- self.__binds[t] = bind
+ self._add_bind(mapper, bind)
def bind_table(self, table, bind):
- """Bind operations on a Table to a Connectable.
-
- table
- A :class:`.Table` instance
+ """Associate a :class:`.Table` with a "bind", e.g. a :class:`.Engine`
+ or :class:`.Connection`.
- bind
- Any Connectable: a :class:`.Engine` or :class:`.Connection`.
-
- All subsequent operations involving this :class:`.Table` will use the
- given `bind`.
+ The given mapper is added to a lookup used by the
+ :meth:`.Session.get_bind` method.
"""
- self.__binds[table] = bind
+ self._add_bind(table, bind)
def get_bind(self, mapper=None, clause=None):
"""Return a "bind" to which this :class:`.Session` is bound.
@@ -1113,6 +1114,7 @@ class Session(_SessionClassMethods):
bound :class:`.MetaData`.
"""
+
if mapper is clause is None:
if self.bind:
return self.bind
@@ -1122,15 +1124,23 @@ class Session(_SessionClassMethods):
"Connection, and no context was provided to locate "
"a binding.")
- c_mapper = mapper is not None and _class_to_mapper(mapper) or None
+ if mapper is not None:
+ try:
+ mapper = inspect(mapper)
+ except sa_exc.NoInspectionAvailable:
+ if isinstance(mapper, type):
+ raise exc.UnmappedClassError(mapper)
+ else:
+ raise
- # manually bound?
if self.__binds:
- if c_mapper:
- if c_mapper.base_mapper in self.__binds:
- return self.__binds[c_mapper.base_mapper]
- elif c_mapper.mapped_table in self.__binds:
- return self.__binds[c_mapper.mapped_table]
+ if mapper:
+ for cls in mapper.class_.__mro__:
+ if cls in self.__binds:
+ return self.__binds[cls]
+ if clause is None:
+ clause = mapper.mapped_table
+
if clause is not None:
for t in sql_util.find_tables(clause, include_crud=True):
if t in self.__binds:
@@ -1142,12 +1152,12 @@ class Session(_SessionClassMethods):
if isinstance(clause, sql.expression.ClauseElement) and clause.bind:
return clause.bind
- if c_mapper and c_mapper.mapped_table.bind:
- return c_mapper.mapped_table.bind
+ if mapper and mapper.mapped_table.bind:
+ return mapper.mapped_table.bind
context = []
if mapper is not None:
- context.append('mapper %s' % c_mapper)
+ context.append('mapper %s' % mapper)
if clause is not None:
context.append('SQL expression')
diff --git a/test/orm/test_bind.py b/test/orm/test_bind.py
index 3e5af0cba..33cd66ebc 100644
--- a/test/orm/test_bind.py
+++ b/test/orm/test_bind.py
@@ -1,13 +1,14 @@
from sqlalchemy.testing import assert_raises_message
-from sqlalchemy import MetaData, Integer
+from sqlalchemy import MetaData, Integer, ForeignKey
from sqlalchemy.testing.schema import Table
from sqlalchemy.testing.schema import Column
from sqlalchemy.orm import mapper, create_session
import sqlalchemy as sa
from sqlalchemy import testing
-from sqlalchemy.testing import fixtures, eq_, engines
+from sqlalchemy.testing import fixtures, eq_, engines, is_
from sqlalchemy.orm import relationship, Session, backref, sessionmaker
from test.orm import _fixtures
+from sqlalchemy.testing.mock import Mock
class BindIntegrationTest(_fixtures.FixtureTest):
@@ -249,3 +250,218 @@ class SessionBindTest(fixtures.MappedTest):
('Could not locate a bind configured on Mapper|Foo|test_table '
'or this Session'),
sess.flush)
+
+
+class GetBindTest(fixtures.MappedTest):
+ @classmethod
+ def define_tables(cls, metadata):
+ Table(
+ 'base_table', metadata,
+ Column('id', Integer, primary_key=True)
+ )
+ Table(
+ 'w_mixin_table', metadata,
+ Column('id', Integer, primary_key=True)
+ )
+ Table(
+ 'joined_sub_table', metadata,
+ Column('id', ForeignKey('base_table.id'), primary_key=True)
+ )
+ Table(
+ 'concrete_sub_table', metadata,
+ Column('id', Integer, primary_key=True)
+ )
+
+ @classmethod
+ def setup_classes(cls):
+ class MixinOne(cls.Basic):
+ pass
+
+ class BaseClass(cls.Basic):
+ pass
+
+ class ClassWMixin(MixinOne, cls.Basic):
+ pass
+
+ class JoinedSubClass(BaseClass):
+ pass
+
+ class ConcreteSubClass(BaseClass):
+ pass
+
+ @classmethod
+ def setup_mappers(cls):
+ mapper(cls.classes.ClassWMixin, cls.tables.w_mixin_table)
+ mapper(cls.classes.BaseClass, cls.tables.base_table)
+ mapper(
+ cls.classes.JoinedSubClass,
+ cls.tables.joined_sub_table, inherits=cls.classes.BaseClass)
+ mapper(
+ cls.classes.ConcreteSubClass,
+ cls.tables.concrete_sub_table, inherits=cls.classes.BaseClass,
+ concrete=True)
+
+ def _fixture(self, binds):
+ return Session(binds=binds)
+
+ def test_fallback_table_metadata(self):
+ session = self._fixture({})
+ is_(
+ session.get_bind(self.classes.BaseClass),
+ testing.db
+ )
+
+ def test_bind_base_table_base_class(self):
+ base_class_bind = Mock()
+ session = self._fixture({
+ self.tables.base_table: base_class_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.BaseClass),
+ base_class_bind
+ )
+
+ def test_bind_base_table_joined_sub_class(self):
+ base_class_bind = Mock()
+ session = self._fixture({
+ self.tables.base_table: base_class_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.BaseClass),
+ base_class_bind
+ )
+ is_(
+ session.get_bind(self.classes.JoinedSubClass),
+ base_class_bind
+ )
+
+ def test_bind_joined_sub_table_joined_sub_class(self):
+ base_class_bind = Mock(name='base')
+ joined_class_bind = Mock(name='joined')
+ session = self._fixture({
+ self.tables.base_table: base_class_bind,
+ self.tables.joined_sub_table: joined_class_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.BaseClass),
+ base_class_bind
+ )
+ # joined table inheritance has to query based on the base
+ # table, so this is what we expect
+ is_(
+ session.get_bind(self.classes.JoinedSubClass),
+ base_class_bind
+ )
+
+ def test_bind_base_table_concrete_sub_class(self):
+ base_class_bind = Mock()
+ session = self._fixture({
+ self.tables.base_table: base_class_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.ConcreteSubClass),
+ testing.db
+ )
+
+ def test_bind_sub_table_concrete_sub_class(self):
+ base_class_bind = Mock(name='base')
+ concrete_sub_bind = Mock(name='concrete')
+
+ session = self._fixture({
+ self.tables.base_table: base_class_bind,
+ self.tables.concrete_sub_table: concrete_sub_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.BaseClass),
+ base_class_bind
+ )
+ is_(
+ session.get_bind(self.classes.ConcreteSubClass),
+ concrete_sub_bind
+ )
+
+ def test_bind_base_class_base_class(self):
+ base_class_bind = Mock()
+ session = self._fixture({
+ self.classes.BaseClass: base_class_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.BaseClass),
+ base_class_bind
+ )
+
+ def test_bind_mixin_class_simple_class(self):
+ base_class_bind = Mock()
+ session = self._fixture({
+ self.classes.MixinOne: base_class_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.ClassWMixin),
+ base_class_bind
+ )
+
+ def test_bind_base_class_joined_sub_class(self):
+ base_class_bind = Mock()
+ session = self._fixture({
+ self.classes.BaseClass: base_class_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.JoinedSubClass),
+ base_class_bind
+ )
+
+ def test_bind_joined_sub_class_joined_sub_class(self):
+ base_class_bind = Mock(name='base')
+ joined_class_bind = Mock(name='joined')
+ session = self._fixture({
+ self.classes.BaseClass: base_class_bind,
+ self.classes.JoinedSubClass: joined_class_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.BaseClass),
+ base_class_bind
+ )
+ is_(
+ session.get_bind(self.classes.JoinedSubClass),
+ joined_class_bind
+ )
+
+ def test_bind_base_class_concrete_sub_class(self):
+ base_class_bind = Mock()
+ session = self._fixture({
+ self.classes.BaseClass: base_class_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.ConcreteSubClass),
+ base_class_bind
+ )
+
+ def test_bind_sub_class_concrete_sub_class(self):
+ base_class_bind = Mock(name='base')
+ concrete_sub_bind = Mock(name='concrete')
+
+ session = self._fixture({
+ self.classes.BaseClass: base_class_bind,
+ self.classes.ConcreteSubClass: concrete_sub_bind
+ })
+
+ is_(
+ session.get_bind(self.classes.BaseClass),
+ base_class_bind
+ )
+ is_(
+ session.get_bind(self.classes.ConcreteSubClass),
+ concrete_sub_bind
+ )
+
+
diff --git a/test/orm/test_session.py b/test/orm/test_session.py
index 06d1d7334..b0b00d5ed 100644
--- a/test/orm/test_session.py
+++ b/test/orm/test_session.py
@@ -1403,14 +1403,19 @@ class SessionInterface(fixtures.TestBase):
eq_(watchdog, instance_methods,
watchdog.symmetric_difference(instance_methods))
- def _test_class_guards(self, user_arg):
+ def _test_class_guards(self, user_arg, is_class=True):
watchdog = set()
def raises_(method, *args, **kw):
watchdog.add(method)
callable_ = getattr(create_session(), method)
- assert_raises(sa.orm.exc.UnmappedClassError,
- callable_, *args, **kw)
+ if is_class:
+ assert_raises(
+ sa.orm.exc.UnmappedClassError,
+ callable_, *args, **kw)
+ else:
+ assert_raises(
+ sa.exc.NoInspectionAvailable, callable_, *args, **kw)
raises_('connection', mapper=user_arg)
@@ -1433,7 +1438,7 @@ class SessionInterface(fixtures.TestBase):
def test_unmapped_primitives(self):
for prim in ('doh', 123, ('t', 'u', 'p', 'l', 'e')):
self._test_instance_guards(prim)
- self._test_class_guards(prim)
+ self._test_class_guards(prim, is_class=False)
def test_unmapped_class_for_instance(self):
class Unmapped(object):
@@ -1457,7 +1462,7 @@ class SessionInterface(fixtures.TestBase):
self._map_it(Mapped)
self._test_instance_guards(early)
- self._test_class_guards(early)
+ self._test_class_guards(early, is_class=False)
class TLTransactionTest(fixtures.MappedTest):