diff options
-rw-r--r-- | neutron/api/rpc/handlers/securitygroups_rpc.py | 14 | ||||
-rw-r--r-- | neutron/db/securitygroups_rpc_base.py | 13 | ||||
-rw-r--r-- | neutron/plugins/ml2/db.py | 72 | ||||
-rw-r--r-- | neutron/plugins/ml2/plugin.py | 18 | ||||
-rw-r--r-- | neutron/tests/unit/ml2/test_security_group.py | 107 |
5 files changed, 165 insertions, 59 deletions
diff --git a/neutron/api/rpc/handlers/securitygroups_rpc.py b/neutron/api/rpc/handlers/securitygroups_rpc.py index 2a748cfbcc..e4a16b2f58 100644 --- a/neutron/api/rpc/handlers/securitygroups_rpc.py +++ b/neutron/api/rpc/handlers/securitygroups_rpc.py @@ -36,15 +36,11 @@ class SecurityGroupServerRpcCallback(n_rpc.RpcCallback): return manager.NeutronManager.get_plugin() def _get_devices_info(self, devices): - devices_info = {} - for device in devices: - port = self.plugin.get_port_from_device(device) - if not port: - continue - if port['device_owner'].startswith('network:'): - continue - devices_info[port['id']] = port - return devices_info + return dict( + (port['id'], port) + for port in self.plugin.get_ports_from_devices(devices) + if port and not port['device_owner'].startswith('network:') + ) def security_group_rules_for_devices(self, context, **kwargs): """Callback method to return security group rules for each port. diff --git a/neutron/db/securitygroups_rpc_base.py b/neutron/db/securitygroups_rpc_base.py index 233a50a3ac..f570fea7f6 100644 --- a/neutron/db/securitygroups_rpc_base.py +++ b/neutron/db/securitygroups_rpc_base.py @@ -40,7 +40,7 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin): def get_port_from_device(self, device): """Get port dict from device name on an agent. - Subclass must provide this method. + Subclass must provide this method or get_ports_from_devices. :param device: device name which identifies a port on the agent side. What is specified in "device" depends on a plugin agent implementation. @@ -54,9 +54,18 @@ class SecurityGroupServerRpcMixin(sg_db.SecurityGroupDbMixin): - security_group_source_groups - fixed_ips """ - raise NotImplementedError(_("%s must implement get_port_from_device.") + raise NotImplementedError(_("%s must implement get_port_from_device " + "or get_ports_from_devices.") % self.__class__.__name__) + def get_ports_from_devices(self, devices): + """Bulk method of get_port_from_device. + + Subclasses may override this to provide better performance for DB + queries, backend calls, etc. + """ + return [self.get_port_from_device(device) for device in devices] + def create_security_group_rule(self, context, security_group_rule): bulk_rule = {'security_group_rules': [security_group_rule]} rule = self.create_security_group_rule_bulk_native(context, diff --git a/neutron/plugins/ml2/db.py b/neutron/plugins/ml2/db.py index d8caa9384a..40e1c22e52 100644 --- a/neutron/plugins/ml2/db.py +++ b/neutron/plugins/ml2/db.py @@ -13,6 +13,9 @@ # License for the specific language governing permissions and limitations # under the License. +import collections + +from sqlalchemy import or_ from sqlalchemy.orm import exc from oslo.db import exception as db_exc @@ -30,6 +33,9 @@ from neutron.plugins.ml2 import models LOG = log.getLogger(__name__) +# limit the number of port OR LIKE statements in one query +MAX_PORTS_PER_QUERY = 500 + def _make_segment_dict(record): """Make a segment dictionary out of a DB record.""" @@ -206,32 +212,64 @@ def get_port_from_device_mac(device_mac): return qry.first() -def get_port_and_sgs(port_id): - """Get port from database with security group info.""" +def get_ports_and_sgs(port_ids): + """Get ports from database with security group info.""" + + # break large queries into smaller parts + if len(port_ids) > MAX_PORTS_PER_QUERY: + LOG.debug("Number of ports %(pcount)s exceeds the maximum per " + "query %(maxp)s. Partitioning queries.", + {'pcount': len(port_ids), 'maxp': MAX_PORTS_PER_QUERY}) + return (get_ports_and_sgs(port_ids[:MAX_PORTS_PER_QUERY]) + + get_ports_and_sgs(port_ids[MAX_PORTS_PER_QUERY:])) + + LOG.debug("get_ports_and_sgs() called for port_ids %s", port_ids) - LOG.debug(_("get_port_and_sgs() called for port_id %s"), port_id) + if not port_ids: + # if port_ids is empty, avoid querying to DB to ask it for nothing + return [] + ports_to_sg_ids = get_sg_ids_grouped_by_port(port_ids) + return [make_port_dict_with_security_groups(port, sec_groups) + for port, sec_groups in ports_to_sg_ids.iteritems()] + + +def get_sg_ids_grouped_by_port(port_ids): + sg_ids_grouped_by_port = collections.defaultdict(list) session = db_api.get_session() sg_binding_port = sg_db.SecurityGroupPortBinding.port_id with session.begin(subtransactions=True): + # partial UUIDs must be individually matched with startswith. + # full UUIDs may be matched directly in an IN statement + partial_uuids = set(port_id for port_id in port_ids + if not uuidutils.is_uuid_like(port_id)) + full_uuids = set(port_ids) - partial_uuids + or_criteria = [models_v2.Port.id.startswith(port_id) + for port_id in partial_uuids] + if full_uuids: + or_criteria.append(models_v2.Port.id.in_(full_uuids)) + query = session.query(models_v2.Port, sg_db.SecurityGroupPortBinding.security_group_id) query = query.outerjoin(sg_db.SecurityGroupPortBinding, models_v2.Port.id == sg_binding_port) - query = query.filter(models_v2.Port.id.startswith(port_id)) - port_and_sgs = query.all() - if not port_and_sgs: - return - port = port_and_sgs[0][0] - plugin = manager.NeutronManager.get_plugin() - port_dict = plugin._make_port_dict(port) - port_dict['security_groups'] = [ - sg_id for port_, sg_id in port_and_sgs if sg_id] - port_dict['security_group_rules'] = [] - port_dict['security_group_source_groups'] = [] - port_dict['fixed_ips'] = [ip['ip_address'] - for ip in port['fixed_ips']] - return port_dict + query = query.filter(or_(*or_criteria)) + + for port, sg_id in query: + if sg_id: + sg_ids_grouped_by_port[port].append(sg_id) + return sg_ids_grouped_by_port + + +def make_port_dict_with_security_groups(port, sec_groups): + plugin = manager.NeutronManager.get_plugin() + port_dict = plugin._make_port_dict(port) + port_dict['security_groups'] = sec_groups + port_dict['security_group_rules'] = [] + port_dict['security_group_source_groups'] = [] + port_dict['fixed_ips'] = [ip['ip_address'] + for ip in port['fixed_ips']] + return port_dict def get_port_binding_host(port_id): diff --git a/neutron/plugins/ml2/plugin.py b/neutron/plugins/ml2/plugin.py index 72cf151006..d29deda6ce 100644 --- a/neutron/plugins/ml2/plugin.py +++ b/neutron/plugins/ml2/plugin.py @@ -1156,12 +1156,18 @@ class Ml2Plugin(db_base_plugin_v2.NeutronDbPluginV2, port_host = db.get_port_binding_host(port_id) return (port_host == host) - def get_port_from_device(self, device): - port_id = self._device_to_port_id(device) - port = db.get_port_and_sgs(port_id) - if port: - port['device'] = device - return port + def get_ports_from_devices(self, devices): + port_ids_to_devices = dict((self._device_to_port_id(device), device) + for device in devices) + port_ids = port_ids_to_devices.keys() + ports = db.get_ports_and_sgs(port_ids) + for port in ports: + # map back to original requested id + port_id = next((port_id for port_id in port_ids + if port['id'].startswith(port_id)), None) + port['device'] = port_ids_to_devices.get(port_id) + + return ports def _device_to_port_id(self, device): # REVISIT(rkukura): Consider calling into MechanismDrivers to diff --git a/neutron/tests/unit/ml2/test_security_group.py b/neutron/tests/unit/ml2/test_security_group.py index 39c3cc2bae..cc8468ae23 100644 --- a/neutron/tests/unit/ml2/test_security_group.py +++ b/neutron/tests/unit/ml2/test_security_group.py @@ -14,11 +14,15 @@ # License for the specific language governing permissions and limitations # under the License. +import contextlib +import math import mock from neutron.api.v2 import attributes +from neutron.common import constants as const from neutron.extensions import securitygroup as ext_sg from neutron import manager +from neutron.tests.unit import test_api_v2 from neutron.tests.unit import test_extension_security_group as test_sg from neutron.tests.unit import test_security_groups_rpc as test_sg_rpc @@ -55,38 +59,91 @@ class TestMl2SecurityGroups(Ml2SecurityGroupsTestCase, plugin = manager.NeutronManager.get_plugin() plugin.start_rpc_listeners() - def test_security_group_get_port_from_device(self): + def _make_port_with_new_sec_group(self, net_id): + sg = self._make_security_group(self.fmt, 'name', 'desc') + port = self._make_port( + self.fmt, net_id, security_groups=[sg['security_group']['id']]) + return port['port'] + + def test_security_group_get_ports_from_devices(self): with self.network() as n: with self.subnet(n): - with self.security_group() as sg: - security_group_id = sg['security_group']['id'] - res = self._create_port(self.fmt, n['network']['id']) - port = self.deserialize(self.fmt, res) - fixed_ips = port['port']['fixed_ips'] - data = {'port': {'fixed_ips': fixed_ips, - 'name': port['port']['name'], - ext_sg.SECURITYGROUPS: - [security_group_id]}} - - req = self.new_update_request('ports', data, - port['port']['id']) - res = self.deserialize(self.fmt, - req.get_response(self.api)) - port_id = res['port']['id'] - plugin = manager.NeutronManager.get_plugin() - port_dict = plugin.get_port_from_device(port_id) - self.assertEqual(port_id, port_dict['id']) - self.assertEqual([security_group_id], + port1 = self._make_port_with_new_sec_group(n['network']['id']) + port2 = self._make_port_with_new_sec_group(n['network']['id']) + plugin = manager.NeutronManager.get_plugin() + # should match full ID and starting chars + ports = plugin.get_ports_from_devices( + [port1['id'], port2['id'][0:8]]) + self.assertEqual(2, len(ports)) + for port_dict in ports: + p = port1 if port1['id'] == port_dict['id'] else port2 + self.assertEqual(p['id'], port_dict['id']) + self.assertEqual(p['security_groups'], port_dict[ext_sg.SECURITYGROUPS]) self.assertEqual([], port_dict['security_group_rules']) - self.assertEqual([fixed_ips[0]['ip_address']], + self.assertEqual([p['fixed_ips'][0]['ip_address']], port_dict['fixed_ips']) - self._delete('ports', port_id) + self._delete('ports', p['id']) + + def test_security_group_get_ports_from_devices_with_bad_id(self): + plugin = manager.NeutronManager.get_plugin() + ports = plugin.get_ports_from_devices(['bad_device_id']) + self.assertFalse(ports) - def test_security_group_get_port_from_device_with_no_port(self): + def test_security_group_no_db_calls_with_no_ports(self): + plugin = manager.NeutronManager.get_plugin() + with mock.patch( + 'neutron.plugins.ml2.db.get_sg_ids_grouped_by_port' + ) as get_mock: + self.assertFalse(plugin.get_ports_from_devices([])) + self.assertFalse(get_mock.called) + + def test_large_port_count_broken_into_parts(self): + plugin = manager.NeutronManager.get_plugin() + max_ports_per_query = 5 + ports_to_query = 73 + for max_ports_per_query in (1, 2, 5, 7, 9, 31): + with contextlib.nested( + mock.patch('neutron.plugins.ml2.db.MAX_PORTS_PER_QUERY', + new=max_ports_per_query), + mock.patch('neutron.plugins.ml2.db.get_sg_ids_grouped_by_port', + return_value={}), + ) as (max_mock, get_mock): + plugin.get_ports_from_devices( + ['%s%s' % (const.TAP_DEVICE_PREFIX, i) + for i in range(ports_to_query)]) + all_call_args = map(lambda x: x[1][0], get_mock.mock_calls) + last_call_args = all_call_args.pop() + # all but last should be getting MAX_PORTS_PER_QUERY ports + self.assertTrue( + all(map(lambda x: len(x) == max_ports_per_query, + all_call_args)) + ) + remaining = ports_to_query % max_ports_per_query + if remaining: + self.assertEqual(remaining, len(last_call_args)) + # should be broken into ceil(total/MAX_PORTS_PER_QUERY) calls + self.assertEqual( + math.ceil(ports_to_query / float(max_ports_per_query)), + get_mock.call_count + ) + + def test_full_uuids_skip_port_id_lookup(self): plugin = manager.NeutronManager.get_plugin() - port_dict = plugin.get_port_from_device('bad_device_id') - self.assertIsNone(port_dict) + # when full UUIDs are provided, the _or statement should only + # have one matching 'IN' critiera for all of the IDs + with contextlib.nested( + mock.patch('neutron.plugins.ml2.db.or_'), + mock.patch('neutron.plugins.ml2.db.db_api.get_session') + ) as (or_mock, sess_mock): + fmock = sess_mock.query.return_value.outerjoin.return_value.filter + # return no ports to exit the method early since we are mocking + # the query + fmock.return_value.all.return_value = [] + plugin.get_ports_from_devices([test_api_v2._uuid(), + test_api_v2._uuid()]) + # the or_ function should only have one argument + or_mock.assert_called_once_with(mock.ANY) class TestMl2SecurityGroupsXML(TestMl2SecurityGroups): |