summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--nova/context.py27
-rw-r--r--nova/db/api.py5
-rw-r--r--nova/db/sqlalchemy/api.py41
-rw-r--r--nova/tests/unit/db/test_db_api.py7
-rw-r--r--nova/tests/unit/test_context.py15
5 files changed, 90 insertions, 5 deletions
diff --git a/nova/context.py b/nova/context.py
index 01849c5903..808f5dd3df 100644
--- a/nova/context.py
+++ b/nova/context.py
@@ -17,6 +17,7 @@
"""RequestContext: context for requests that persist through all of nova."""
+from contextlib import contextmanager
import copy
from keystoneauth1.access import service_catalog as ksa_service_catalog
@@ -141,6 +142,12 @@ class RequestContext(context.RequestContext):
self.user_name = user_name
self.project_name = project_name
self.is_admin = is_admin
+
+ # NOTE(dheeraj): The following attribute is used by cellsv2 to store
+ # connection information for connecting to the target cell.
+ # It is only manipulated using the target_cell contextmanager
+ # provided by this module
+ self.db_connection = None
self.user_auth_plugin = user_auth_plugin
if self.is_admin is None:
self.is_admin = policy.check_is_admin(self)
@@ -272,3 +279,23 @@ def authorize_quota_class_context(context, class_name):
raise exception.Forbidden()
elif context.quota_class != class_name:
raise exception.Forbidden()
+
+
+@contextmanager
+def target_cell(context, cell_mapping):
+ """Adds database connection information to the context for communicating
+ with the given target cell.
+
+ :param context: The RequestContext to add database connection information
+ :param cell_mapping: A objects.CellMapping object
+ """
+ original_db_connection = context.db_connection
+ # avoid circular import
+ from nova import db
+ connection_string = cell_mapping.database_connection
+ context.db_connection = db.create_context_manager(connection_string)
+
+ try:
+ yield context
+ finally:
+ context.db_connection = original_db_connection
diff --git a/nova/db/api.py b/nova/db/api.py
index cb1ff7722f..39b04c731f 100644
--- a/nova/db/api.py
+++ b/nova/db/api.py
@@ -86,6 +86,11 @@ def not_equal(*values):
return IMPL.not_equal(*values)
+def create_context_manager(connection):
+ """Return a context manager for a cell database connection."""
+ return IMPL.create_context_manager(connection=connection)
+
+
###################
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py
index 511c87dd92..81d307b10c 100644
--- a/nova/db/sqlalchemy/api.py
+++ b/nova/db/sqlalchemy/api.py
@@ -135,9 +135,9 @@ main_context_manager = enginefacade.transaction_context()
api_context_manager = enginefacade.transaction_context()
-def _get_db_conf(conf_group):
+def _get_db_conf(conf_group, connection=None):
kw = dict(
- connection=conf_group.connection,
+ connection=connection or conf_group.connection,
slave_connection=conf_group.slave_connection,
sqlite_fk=False,
__autocommit=True,
@@ -155,14 +155,45 @@ def _get_db_conf(conf_group):
return kw
+def _context_manager_from_context(context):
+ if context:
+ try:
+ return context.db_connection
+ except AttributeError:
+ pass
+
+
def configure(conf):
main_context_manager.configure(**_get_db_conf(conf.database))
api_context_manager.configure(**_get_db_conf(conf.api_database))
-def get_engine(use_slave=False):
- return main_context_manager.get_legacy_facade().get_engine(
- use_slave=use_slave)
+def create_context_manager(connection=None):
+ """Create a database context manager object.
+
+ : param connection: The database connection string
+ """
+ ctxt_mgr = enginefacade.transaction_context()
+ ctxt_mgr.configure(**_get_db_conf(CONF.database, connection=connection))
+ return ctxt_mgr
+
+
+def get_context_manager(context):
+ """Get a database context manager object.
+
+ :param context: The request context that can contain a context manager
+ """
+ return _context_manager_from_context(context) or main_context_manager
+
+
+def get_engine(use_slave=False, context=None):
+ """Get a database engine object.
+
+ :param use_slave: Whether to use the slave connection
+ :param context: The request context that can contain a context manager
+ """
+ ctxt_mgr = _context_manager_from_context(context) or main_context_manager
+ return ctxt_mgr.get_legacy_facade().get_engine(use_slave=use_slave)
def get_api_engine():
diff --git a/nova/tests/unit/db/test_db_api.py b/nova/tests/unit/db/test_db_api.py
index 53e9d9aec4..870ed3ef31 100644
--- a/nova/tests/unit/db/test_db_api.py
+++ b/nova/tests/unit/db/test_db_api.py
@@ -1070,6 +1070,13 @@ class SqlAlchemyDbApiNoDbTestCase(test.NoDBTestCase):
mock_create_facade.assert_called_once_with()
mock_facade.get_engine.assert_called_once_with(use_slave=False)
+ def test_get_db_conf_with_connection(self):
+ mock_conf_group = mock.MagicMock()
+ mock_conf_group.connection = 'fakemain://'
+ db_conf = sqlalchemy_api._get_db_conf(mock_conf_group,
+ connection='fake://')
+ self.assertEqual('fake://', db_conf['connection'])
+
@mock.patch.object(sqlalchemy_api.api_context_manager._factory,
'get_legacy_facade')
def test_get_api_engine(self, mock_create_facade):
diff --git a/nova/tests/unit/test_context.py b/nova/tests/unit/test_context.py
index 5a4651d803..65feeb71a4 100644
--- a/nova/tests/unit/test_context.py
+++ b/nova/tests/unit/test_context.py
@@ -12,10 +12,12 @@
# License for the specific language governing permissions and limitations
# under the License.
+import mock
from oslo_context import context as o_context
from oslo_context import fixture as o_fixture
from nova import context
+from nova import objects
from nova import test
@@ -223,3 +225,16 @@ class ContextTestCase(test.NoDBTestCase):
self.assertEqual('222', ctx.project_id)
values2 = ctx.to_dict()
self.assertEqual(values, values2)
+
+ @mock.patch('nova.db.create_context_manager')
+ def test_target_cell(self, mock_create_ctxt_mgr):
+ mock_create_ctxt_mgr.return_value = mock.sentinel.cm
+ ctxt = context.RequestContext('111',
+ '222',
+ roles=['admin', 'weasel'])
+ # Verify the existing db_connection, if any, is restored
+ ctxt.db_connection = mock.sentinel.db_conn
+ mapping = objects.CellMapping(database_connection='fake://')
+ with context.target_cell(ctxt, mapping):
+ self.assertEqual(ctxt.db_connection, mock.sentinel.cm)
+ self.assertEqual(mock.sentinel.db_conn, ctxt.db_connection)