diff options
author | Julia Kreger <juliaashleykreger@gmail.com> | 2022-09-13 14:58:15 -0700 |
---|---|---|
committer | Julia Kreger <juliaashleykreger@gmail.com> | 2022-10-14 07:58:27 -0700 |
commit | c2df29e52f44e39c9fa3cc6ef0820749db1c8647 (patch) | |
tree | 7869d05cf829ed147e0a21f3b0643df75daeee82 /ironic/db/sqlalchemy | |
parent | 49e085583dec81c63d19f80a9ba067e38d8043ae (diff) | |
download | ironic-c2df29e52f44e39c9fa3cc6ef0820749db1c8647.tar.gz |
Phase 2 - SQLAlchemy 2.0 Compatability
* Changed common exception imports from SQLAlchemy for ORM
query types which are now originated from the main exception
definition set.
* Changed base join option usage to use objects instead of labels,
and defaulted all multi-row result sets to return data using
"selectinload" as opposed to operating with a join query to
avoid need to de-duplicate all result sets.
* Changed DeployTemplates to utilize objects instead of
field names for queries, and updated the associated join ORM
model's relationship record between DeployTemplate and
DeployTemplateSteps.
* Changed Ports, Chassis, Conductor, Volume Target/Connector
queries to lean towards use of select/update/delete queries as
opposed to ORM queries. Most of these changes revolved around
references of field names as opposed to model objects.
* This change also labels a few lines as "noqa", which is a result
of the style check rules getting triggered on statements as
needed for SQLAlchemy.
Change-Id: I651ec4b50c79be6aa8c798ee27957ed720a578d8
Diffstat (limited to 'ironic/db/sqlalchemy')
-rw-r--r-- | ironic/db/sqlalchemy/api.py | 218 | ||||
-rw-r--r-- | ironic/db/sqlalchemy/models.py | 5 |
2 files changed, 143 insertions, 80 deletions
diff --git a/ironic/db/sqlalchemy/api.py b/ironic/db/sqlalchemy/api.py index eadfce776..b05af3637 100644 --- a/ironic/db/sqlalchemy/api.py +++ b/ironic/db/sqlalchemy/api.py @@ -33,7 +33,7 @@ from oslo_utils import uuidutils from osprofiler import sqlalchemy as osp_sqlalchemy import sqlalchemy as sa from sqlalchemy import or_ -from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound +from sqlalchemy.exc import NoResultFound, MultipleResultsFound from sqlalchemy.orm import joinedload from sqlalchemy.orm import Load from sqlalchemy.orm import selectinload @@ -139,8 +139,8 @@ def _get_node_query_with_all_for_single_node(): # 2.43 seconds to obtain all nodes from SQLAlchemy (10k nodes) # 5.15 seconds to obtain all nodes *and* have node objects (10k nodes) return (model_query(models.Node) - .options(joinedload('tags')) - .options(joinedload('traits'))) + .options(joinedload(models.Node.tags)) + .options(joinedload(models.Node.traits))) def _get_node_select(): @@ -172,12 +172,23 @@ def _get_node_select(): selectinload(models.Node.traits))) +def _get_deploy_template_select_with_steps(): + """Return a select object for the DeployTemplate joined with steps. + + :returns: a select object. + """ + return sa.select( + models.DeployTemplate + ).options(selectinload(models.DeployTemplate.steps)) + + def _get_deploy_template_query_with_steps(): """Return a query object for the DeployTemplate joined with steps. :returns: a query object. """ - return model_query(models.DeployTemplate).options(joinedload('steps')) + return model_query(models.DeployTemplate).options( + selectinload(models.DeployTemplate.steps)) def model_query(model, *args, **kwargs): @@ -209,6 +220,26 @@ def add_identity_filter(query, value): raise exception.InvalidIdentity(identity=value) +def add_identity_where(op, model, value): + """Adds an identity filter to operation for where method. + + Filters results by ID, if supplied value is a valid integer. + Otherwise attempts to filter results by UUID. + + :param op: Initial operation to add filter to. + i.e. a update or delete statement. + :param model: The SQLAlchemy model to apply. + :param value: Value for filtering results by. + :return: Modified query. + """ + if strutils.is_int_like(value): + return op.where(model.id == value) + elif uuidutils.is_uuid_like(value): + return op.where(model.uuid == value) + else: + raise exception.InvalidIdentity(identity=value) + + def add_port_filter(query, value): """Adds a port-specific filter to a query. @@ -281,7 +312,7 @@ def add_portgroup_filter(query, value): if netutils.is_valid_mac(value): return query.filter_by(address=value) else: - return add_identity_filter(query, value) + return add_identity_where(query, models.Portgroup, value) def add_portgroup_filter_by_node(query, value): @@ -352,8 +383,8 @@ def _paginate_query(model, limit=None, marker=None, sort_key=None, # sets in ORM mode, so we need to explicitly ask for it to be unique # before returning it to the caller. if isinstance(query, sa_orm.Query): - # The classic ORM query object result set which is deprecated - # in advance of SQLAlchemy 2.0. + # The classic "Legacy" ORM query object result set which is + # deprecated in advance of SQLAlchemy 2.0. return query.all() else: # In this case, we have a sqlalchemy.sql.selectable.Select @@ -936,13 +967,14 @@ class Connection(api.Connection): ref.update(values) - # Return the updated node model joined with all relevant fields. - query = _get_node_query_with_all_for_single_node() - query = add_identity_filter(query, node_id) - # FIXME(TheJulia): This entire method needs to be re-written to - # use the proper execution format for SQLAlchemy 2.0. Likely - # A query, independent update, and a re-query on the transaction. - return query.one() + # Return the updated node model joined with all relevant fields. + query = _get_node_select() + query = add_identity_filter(query, node_id) + # FIXME(TheJulia): This entire method needs to be re-written to + # use the proper execution format for SQLAlchemy 2.0. Likely + # A query, independent update, and a re-query on the transaction. + with _session_for_read() as session: + return session.execute(query).one()[0] def get_port_by_id(self, port_id): query = model_query(models.Port).filter_by(id=port_id) @@ -979,7 +1011,7 @@ class Connection(api.Connection): def get_port_list(self, limit=None, marker=None, sort_key=None, sort_dir=None, owner=None, project=None): - query = model_query(models.Port) + query = sa.select(models.Port) if owner: query = add_port_filter_by_node_owner(query, owner) elif project: @@ -990,8 +1022,7 @@ class Connection(api.Connection): def get_ports_by_node_id(self, node_id, limit=None, marker=None, sort_key=None, sort_dir=None, owner=None, project=None): - query = model_query(models.Port) - query = query.filter_by(node_id=node_id) + query = sa.select(models.Port).where(models.Port.node_id == node_id) if owner: query = add_port_filter_by_node_owner(query, owner) elif project: @@ -1002,8 +1033,10 @@ class Connection(api.Connection): def get_ports_by_portgroup_id(self, portgroup_id, limit=None, marker=None, sort_key=None, sort_dir=None, owner=None, project=None): - query = model_query(models.Port) - query = query.filter_by(portgroup_id=portgroup_id) + query = sa.select(models.Port).where( + models.Port.portgroup_id == portgroup_id + ) + if owner: query = add_port_filter_by_node_owner(query, owner) elif project: @@ -1034,7 +1067,6 @@ class Connection(api.Connection): if 'uuid' in values: msg = _("Cannot overwrite UUID for an existing Port.") raise exception.InvalidParameterValue(err=msg) - try: with _session_for_write() as session: query = model_query(models.Port) @@ -1103,7 +1135,7 @@ class Connection(api.Connection): def get_portgroups_by_node_id(self, node_id, limit=None, marker=None, sort_key=None, sort_dir=None, project=None): query = model_query(models.Portgroup) - query = query.filter_by(node_id=node_id) + query = query.where(models.Portgroup.node_id == node_id) if project: query = add_portgroup_filter_by_node_project(query, project) return _paginate_query(models.Portgroup, limit, marker, @@ -1160,34 +1192,40 @@ class Connection(api.Connection): def destroy_portgroup(self, portgroup_id): def portgroup_not_empty(session): """Checks whether the portgroup does not have ports.""" - - query = model_query(models.Port) - query = add_port_filter_by_portgroup(query, portgroup_id) - - return query.count() != 0 + with _session_for_read() as session: + return session.scalar( + sa.select( + sa.func.count(models.Port.id) + ).where(models.Port.portgroup_id == portgroup_id)) != 0 with _session_for_write() as session: if portgroup_not_empty(session): raise exception.PortgroupNotEmpty(portgroup=portgroup_id) - query = model_query(models.Portgroup, session=session) - query = add_identity_filter(query, portgroup_id) + query = sa.delete(models.Portgroup) + query = add_identity_where(query, models.Portgroup, portgroup_id) - count = query.delete() + count = session.execute(query).rowcount if count == 0: raise exception.PortgroupNotFound(portgroup=portgroup_id) def get_chassis_by_id(self, chassis_id): - query = model_query(models.Chassis).filter_by(id=chassis_id) + query = sa.select(models.Chassis).where( + models.Chassis.id == chassis_id) + try: - return query.one() + with _session_for_read() as session: + return session.execute(query).one()[0] except NoResultFound: raise exception.ChassisNotFound(chassis=chassis_id) def get_chassis_by_uuid(self, chassis_uuid): - query = model_query(models.Chassis).filter_by(uuid=chassis_uuid) + query = sa.select(models.Chassis).where( + models.Chassis.uuid == chassis_uuid) + try: - return query.one() + with _session_for_read() as session: + return session.execute(query).one()[0] except NoResultFound: raise exception.ChassisNotFound(chassis=chassis_uuid) @@ -1220,7 +1258,7 @@ class Connection(api.Connection): with _session_for_write(): query = model_query(models.Chassis) - query = add_identity_filter(query, chassis_id) + query = add_identity_where(query, models.Chassis, chassis_id) count = query.update(values) if count != 1: @@ -1276,27 +1314,32 @@ class Connection(api.Connection): def get_conductor(self, hostname, online=True): try: - query = model_query(models.Conductor).filter_by(hostname=hostname) + query = sa.select(models.Conductor).where( + models.Conductor.hostname == hostname) if online is not None: - query = query.filter_by(online=online) - return query.one() + query = query.where(models.Conductor.online == online) + with _session_for_read() as session: + res = session.execute(query).one()[0] + return res except NoResultFound: raise exception.ConductorNotFound(conductor=hostname) @oslo_db_api.retry_on_deadlock def unregister_conductor(self, hostname): - with _session_for_write(): - query = (model_query(models.Conductor) - .filter_by(hostname=hostname, online=True)) - count = query.update({'online': False}) + with _session_for_write() as session: + query = sa.update(models.Conductor).where( + models.Conductor.hostname == hostname, + models.Conductor.online == True).values( # noqa + online=False) + count = session.execute(query).rowcount if count == 0: raise exception.ConductorNotFound(conductor=hostname) @oslo_db_api.retry_on_deadlock def touch_conductor(self, hostname): with _session_for_write(): - query = (model_query(models.Conductor) - .filter_by(hostname=hostname)) + query = model_query(models.Conductor) + query = query.where(models.Conductor.hostname == hostname) # since we're not changing any other field, manually set updated_at # and since we're heartbeating, make sure that online=True count = query.update({'updated_at': timeutils.utcnow(), @@ -1371,7 +1414,7 @@ class Connection(api.Connection): def list_conductor_hardware_interfaces(self, conductor_id): query = (model_query(models.ConductorHardwareInterfaces) - .filter_by(conductor_id=conductor_id)) + .where(models.ConductorHardwareInterfaces.conductor_id == conductor_id)) # noqa return query.all() def list_hardware_type_interfaces(self, hardware_types): @@ -1422,7 +1465,8 @@ class Connection(api.Connection): raise exception.NodeNotFound(node=node_id) def _check_node_exists(self, node_id): - if not model_query(models.Node).filter_by(id=node_id).scalar(): + if not model_query(models.Node).where( + models.Node.id == node_id).scalar(): raise exception.NodeNotFound(node=node_id) @oslo_db_api.retry_on_deadlock @@ -1483,14 +1527,17 @@ class Connection(api.Connection): return model_query(q.exists()).scalar() def get_node_by_port_addresses(self, addresses): - q = _get_node_query_with_all_for_single_node() + q = _get_node_select() q = q.distinct().join(models.Port) q = q.filter(models.Port.address.in_(addresses)) try: # FIXME(TheJulia): This needs to be updated to be # an explicit query to identify the node for SQLAlchemy. - return q.one() + with _session_for_read() as session: + # Always return the first element, since we always + # get a tuple from sqlalchemy. + return session.execute(q).one()[0] except NoResultFound: raise exception.NodeNotFound( _('Node with port addresses %s was not found') @@ -1526,7 +1573,8 @@ class Connection(api.Connection): def get_volume_connectors_by_node_id(self, node_id, limit=None, marker=None, sort_key=None, sort_dir=None, project=None): - query = model_query(models.VolumeConnector).filter_by(node_id=node_id) + query = model_query(models.VolumeConnector).where( + models.VolumeConnector.node_id == node_id) if project: add_volume_conn_filter_by_node_project(query, project) return _paginate_query(models.VolumeConnector, limit, marker, @@ -1594,7 +1642,8 @@ class Connection(api.Connection): sort_key, sort_dir, query) def get_volume_target_by_id(self, db_id): - query = model_query(models.VolumeTarget).filter_by(id=db_id) + query = model_query(models.VolumeTarget).where( + models.VolumeTarget.id == db_id) try: return query.one() except NoResultFound: @@ -1619,7 +1668,8 @@ class Connection(api.Connection): def get_volume_targets_by_volume_id(self, volume_id, limit=None, marker=None, sort_key=None, sort_dir=None, project=None): - query = model_query(models.VolumeTarget).filter_by(volume_id=volume_id) + query = model_query(models.VolumeTarget).where( + models.VolumeTarget.volume_id == volume_id) if project: query = add_volume_target_filter_by_node_project(query, project) return _paginate_query(models.VolumeTarget, limit, marker, sort_key, @@ -2330,30 +2380,29 @@ class Connection(api.Connection): # this does not work with PostgreSQL. query = model_query(models.DeployTemplate) query = add_identity_filter(query, template_id) - try: - ref = query.with_for_update().one() - except NoResultFound: - raise exception.DeployTemplateNotFound( - template=template_id) - + ref = query.with_for_update().one() # First, update non-step columns. steps = values.pop('steps', None) ref.update(values) - # If necessary, update steps. if steps is not None: self._update_deploy_template_steps(session, ref.id, steps) + session.flush() + with _session_for_read() as session: # Return the updated template joined with all relevant fields. - query = _get_deploy_template_query_with_steps() + query = _get_deploy_template_select_with_steps() query = add_identity_filter(query, template_id) - # FIXME(TheJulia): This needs to be fixed for SQLAlchemy 2.0. - return query.one() + return session.execute(query).one()[0] except db_exc.DBDuplicateEntry as e: if 'name' in e.columns: raise exception.DeployTemplateDuplicateName( name=values['name']) raise + except NoResultFound: + # TODO(TheJulia): What would unified core raise?!? + raise exception.DeployTemplateNotFound( + template=template_id) @oslo_db_api.retry_on_deadlock def destroy_deploy_template(self, template_id): @@ -2367,22 +2416,26 @@ class Connection(api.Connection): def _get_deploy_template(self, field, value): """Helper method for retrieving a deploy template.""" - query = (_get_deploy_template_query_with_steps() - .filter_by(**{field: value})) + query = (_get_deploy_template_select_with_steps() + .where(field == value)) try: # FIXME(TheJulia): This needs to be fixed for SQLAlchemy 2.0 - return query.one() + with _session_for_read() as session: + return session.execute(query).one()[0] except NoResultFound: raise exception.DeployTemplateNotFound(template=value) def get_deploy_template_by_id(self, template_id): - return self._get_deploy_template('id', template_id) + return self._get_deploy_template(models.DeployTemplate.id, + template_id) def get_deploy_template_by_uuid(self, template_uuid): - return self._get_deploy_template('uuid', template_uuid) + return self._get_deploy_template(models.DeployTemplate.uuid, + template_uuid) def get_deploy_template_by_name(self, template_name): - return self._get_deploy_template('name', template_name) + return self._get_deploy_template(models.DeployTemplate.name, + template_name) def get_deploy_template_list(self, limit=None, marker=None, sort_key=None, sort_dir=None): @@ -2391,9 +2444,14 @@ class Connection(api.Connection): sort_key, sort_dir, query) def get_deploy_template_list_by_names(self, names): - query = (_get_deploy_template_query_with_steps() - .filter(models.DeployTemplate.name.in_(names))) - return query.all() + query = _get_deploy_template_select_with_steps() + with _session_for_read() as session: + res = session.execute( + query.where( + models.DeployTemplate.name.in_(names) + ) + ).all() + return [r[0] for r in res] @oslo_db_api.retry_on_deadlock def create_node_history(self, values): @@ -2440,7 +2498,7 @@ class Connection(api.Connection): def get_node_history_by_node_id(self, node_id, limit=None, marker=None, sort_key=None, sort_dir=None): query = model_query(models.NodeHistory) - query = query.filter_by(node_id=node_id) + query = query.where(models.NodeHistory.node_id == node_id) return _paginate_query(models.NodeHistory, limit, marker, sort_key, sort_dir, query) @@ -2507,6 +2565,9 @@ class Connection(api.Connection): # Uses input entry list, selects entries matching those ids # then deletes them and does not synchronize the session so # sqlalchemy doesn't do extra un-necessary work. + # NOTE(TheJulia): This is "legacy" syntax, but it is still + # valid and under the hood SQLAlchemy rewrites the form into + # a delete syntax. session.query( models.NodeHistory ).filter( @@ -2525,13 +2586,12 @@ class Connection(api.Connection): # literally have the DB do *all* of the world, so no # client side ops occur. The column is also indexed, # which means this will be an index based response. - # TODO(TheJulia): This might need to be revised for - # SQLAlchemy 2.0 as it should be a scaler select and count - # instead. - return session.query( - models.Node.provision_state - ).filter( - or_( - models.Node.provision_state == v for v in state + return session.scalar( + sa.select( + sa.func.count(models.Node.id) + ).filter( + or_( + models.Node.provision_state == v for v in state + ) ) - ).count() + ) diff --git a/ironic/db/sqlalchemy/models.py b/ironic/db/sqlalchemy/models.py index 3631d83a9..fad29f095 100644 --- a/ironic/db/sqlalchemy/models.py +++ b/ironic/db/sqlalchemy/models.py @@ -414,6 +414,10 @@ class DeployTemplate(Base): uuid = Column(String(36)) name = Column(String(255), nullable=False) extra = Column(db_types.JsonEncodedDict) + steps: orm.Mapped[List['DeployTemplateStep']] = orm.relationship( # noqa + "DeployTemplateStep", + back_populates="deploy_template", + lazy="selectin") class DeployTemplateStep(Base): @@ -434,7 +438,6 @@ class DeployTemplateStep(Base): priority = Column(Integer, nullable=False) deploy_template = orm.relationship( "DeployTemplate", - backref='steps', primaryjoin=( 'and_(DeployTemplateStep.deploy_template_id == ' 'DeployTemplate.id)'), |