summaryrefslogtreecommitdiff
path: root/ironic
diff options
context:
space:
mode:
authorJulia Kreger <juliaashleykreger@gmail.com>2022-09-13 14:58:15 -0700
committerJulia Kreger <juliaashleykreger@gmail.com>2022-10-14 07:58:27 -0700
commitc2df29e52f44e39c9fa3cc6ef0820749db1c8647 (patch)
tree7869d05cf829ed147e0a21f3b0643df75daeee82 /ironic
parent49e085583dec81c63d19f80a9ba067e38d8043ae (diff)
downloadironic-c2df29e52f44e39c9fa3cc6ef0820749db1c8647.tar.gz
Phase 2 - SQLAlchemy 2.0 Compatability
* Changed common exception imports from SQLAlchemy for ORM query types which are now originated from the main exception definition set. * Changed base join option usage to use objects instead of labels, and defaulted all multi-row result sets to return data using "selectinload" as opposed to operating with a join query to avoid need to de-duplicate all result sets. * Changed DeployTemplates to utilize objects instead of field names for queries, and updated the associated join ORM model's relationship record between DeployTemplate and DeployTemplateSteps. * Changed Ports, Chassis, Conductor, Volume Target/Connector queries to lean towards use of select/update/delete queries as opposed to ORM queries. Most of these changes revolved around references of field names as opposed to model objects. * This change also labels a few lines as "noqa", which is a result of the style check rules getting triggered on statements as needed for SQLAlchemy. Change-Id: I651ec4b50c79be6aa8c798ee27957ed720a578d8
Diffstat (limited to 'ironic')
-rw-r--r--ironic/db/sqlalchemy/api.py218
-rw-r--r--ironic/db/sqlalchemy/models.py5
2 files changed, 143 insertions, 80 deletions
diff --git a/ironic/db/sqlalchemy/api.py b/ironic/db/sqlalchemy/api.py
index eadfce776..b05af3637 100644
--- a/ironic/db/sqlalchemy/api.py
+++ b/ironic/db/sqlalchemy/api.py
@@ -33,7 +33,7 @@ from oslo_utils import uuidutils
from osprofiler import sqlalchemy as osp_sqlalchemy
import sqlalchemy as sa
from sqlalchemy import or_
-from sqlalchemy.orm.exc import NoResultFound, MultipleResultsFound
+from sqlalchemy.exc import NoResultFound, MultipleResultsFound
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Load
from sqlalchemy.orm import selectinload
@@ -139,8 +139,8 @@ def _get_node_query_with_all_for_single_node():
# 2.43 seconds to obtain all nodes from SQLAlchemy (10k nodes)
# 5.15 seconds to obtain all nodes *and* have node objects (10k nodes)
return (model_query(models.Node)
- .options(joinedload('tags'))
- .options(joinedload('traits')))
+ .options(joinedload(models.Node.tags))
+ .options(joinedload(models.Node.traits)))
def _get_node_select():
@@ -172,12 +172,23 @@ def _get_node_select():
selectinload(models.Node.traits)))
+def _get_deploy_template_select_with_steps():
+ """Return a select object for the DeployTemplate joined with steps.
+
+ :returns: a select object.
+ """
+ return sa.select(
+ models.DeployTemplate
+ ).options(selectinload(models.DeployTemplate.steps))
+
+
def _get_deploy_template_query_with_steps():
"""Return a query object for the DeployTemplate joined with steps.
:returns: a query object.
"""
- return model_query(models.DeployTemplate).options(joinedload('steps'))
+ return model_query(models.DeployTemplate).options(
+ selectinload(models.DeployTemplate.steps))
def model_query(model, *args, **kwargs):
@@ -209,6 +220,26 @@ def add_identity_filter(query, value):
raise exception.InvalidIdentity(identity=value)
+def add_identity_where(op, model, value):
+ """Adds an identity filter to operation for where method.
+
+ Filters results by ID, if supplied value is a valid integer.
+ Otherwise attempts to filter results by UUID.
+
+ :param op: Initial operation to add filter to.
+ i.e. a update or delete statement.
+ :param model: The SQLAlchemy model to apply.
+ :param value: Value for filtering results by.
+ :return: Modified query.
+ """
+ if strutils.is_int_like(value):
+ return op.where(model.id == value)
+ elif uuidutils.is_uuid_like(value):
+ return op.where(model.uuid == value)
+ else:
+ raise exception.InvalidIdentity(identity=value)
+
+
def add_port_filter(query, value):
"""Adds a port-specific filter to a query.
@@ -281,7 +312,7 @@ def add_portgroup_filter(query, value):
if netutils.is_valid_mac(value):
return query.filter_by(address=value)
else:
- return add_identity_filter(query, value)
+ return add_identity_where(query, models.Portgroup, value)
def add_portgroup_filter_by_node(query, value):
@@ -352,8 +383,8 @@ def _paginate_query(model, limit=None, marker=None, sort_key=None,
# sets in ORM mode, so we need to explicitly ask for it to be unique
# before returning it to the caller.
if isinstance(query, sa_orm.Query):
- # The classic ORM query object result set which is deprecated
- # in advance of SQLAlchemy 2.0.
+ # The classic "Legacy" ORM query object result set which is
+ # deprecated in advance of SQLAlchemy 2.0.
return query.all()
else:
# In this case, we have a sqlalchemy.sql.selectable.Select
@@ -936,13 +967,14 @@ class Connection(api.Connection):
ref.update(values)
- # Return the updated node model joined with all relevant fields.
- query = _get_node_query_with_all_for_single_node()
- query = add_identity_filter(query, node_id)
- # FIXME(TheJulia): This entire method needs to be re-written to
- # use the proper execution format for SQLAlchemy 2.0. Likely
- # A query, independent update, and a re-query on the transaction.
- return query.one()
+ # Return the updated node model joined with all relevant fields.
+ query = _get_node_select()
+ query = add_identity_filter(query, node_id)
+ # FIXME(TheJulia): This entire method needs to be re-written to
+ # use the proper execution format for SQLAlchemy 2.0. Likely
+ # A query, independent update, and a re-query on the transaction.
+ with _session_for_read() as session:
+ return session.execute(query).one()[0]
def get_port_by_id(self, port_id):
query = model_query(models.Port).filter_by(id=port_id)
@@ -979,7 +1011,7 @@ class Connection(api.Connection):
def get_port_list(self, limit=None, marker=None,
sort_key=None, sort_dir=None, owner=None,
project=None):
- query = model_query(models.Port)
+ query = sa.select(models.Port)
if owner:
query = add_port_filter_by_node_owner(query, owner)
elif project:
@@ -990,8 +1022,7 @@ class Connection(api.Connection):
def get_ports_by_node_id(self, node_id, limit=None, marker=None,
sort_key=None, sort_dir=None, owner=None,
project=None):
- query = model_query(models.Port)
- query = query.filter_by(node_id=node_id)
+ query = sa.select(models.Port).where(models.Port.node_id == node_id)
if owner:
query = add_port_filter_by_node_owner(query, owner)
elif project:
@@ -1002,8 +1033,10 @@ class Connection(api.Connection):
def get_ports_by_portgroup_id(self, portgroup_id, limit=None, marker=None,
sort_key=None, sort_dir=None, owner=None,
project=None):
- query = model_query(models.Port)
- query = query.filter_by(portgroup_id=portgroup_id)
+ query = sa.select(models.Port).where(
+ models.Port.portgroup_id == portgroup_id
+ )
+
if owner:
query = add_port_filter_by_node_owner(query, owner)
elif project:
@@ -1034,7 +1067,6 @@ class Connection(api.Connection):
if 'uuid' in values:
msg = _("Cannot overwrite UUID for an existing Port.")
raise exception.InvalidParameterValue(err=msg)
-
try:
with _session_for_write() as session:
query = model_query(models.Port)
@@ -1103,7 +1135,7 @@ class Connection(api.Connection):
def get_portgroups_by_node_id(self, node_id, limit=None, marker=None,
sort_key=None, sort_dir=None, project=None):
query = model_query(models.Portgroup)
- query = query.filter_by(node_id=node_id)
+ query = query.where(models.Portgroup.node_id == node_id)
if project:
query = add_portgroup_filter_by_node_project(query, project)
return _paginate_query(models.Portgroup, limit, marker,
@@ -1160,34 +1192,40 @@ class Connection(api.Connection):
def destroy_portgroup(self, portgroup_id):
def portgroup_not_empty(session):
"""Checks whether the portgroup does not have ports."""
-
- query = model_query(models.Port)
- query = add_port_filter_by_portgroup(query, portgroup_id)
-
- return query.count() != 0
+ with _session_for_read() as session:
+ return session.scalar(
+ sa.select(
+ sa.func.count(models.Port.id)
+ ).where(models.Port.portgroup_id == portgroup_id)) != 0
with _session_for_write() as session:
if portgroup_not_empty(session):
raise exception.PortgroupNotEmpty(portgroup=portgroup_id)
- query = model_query(models.Portgroup, session=session)
- query = add_identity_filter(query, portgroup_id)
+ query = sa.delete(models.Portgroup)
+ query = add_identity_where(query, models.Portgroup, portgroup_id)
- count = query.delete()
+ count = session.execute(query).rowcount
if count == 0:
raise exception.PortgroupNotFound(portgroup=portgroup_id)
def get_chassis_by_id(self, chassis_id):
- query = model_query(models.Chassis).filter_by(id=chassis_id)
+ query = sa.select(models.Chassis).where(
+ models.Chassis.id == chassis_id)
+
try:
- return query.one()
+ with _session_for_read() as session:
+ return session.execute(query).one()[0]
except NoResultFound:
raise exception.ChassisNotFound(chassis=chassis_id)
def get_chassis_by_uuid(self, chassis_uuid):
- query = model_query(models.Chassis).filter_by(uuid=chassis_uuid)
+ query = sa.select(models.Chassis).where(
+ models.Chassis.uuid == chassis_uuid)
+
try:
- return query.one()
+ with _session_for_read() as session:
+ return session.execute(query).one()[0]
except NoResultFound:
raise exception.ChassisNotFound(chassis=chassis_uuid)
@@ -1220,7 +1258,7 @@ class Connection(api.Connection):
with _session_for_write():
query = model_query(models.Chassis)
- query = add_identity_filter(query, chassis_id)
+ query = add_identity_where(query, models.Chassis, chassis_id)
count = query.update(values)
if count != 1:
@@ -1276,27 +1314,32 @@ class Connection(api.Connection):
def get_conductor(self, hostname, online=True):
try:
- query = model_query(models.Conductor).filter_by(hostname=hostname)
+ query = sa.select(models.Conductor).where(
+ models.Conductor.hostname == hostname)
if online is not None:
- query = query.filter_by(online=online)
- return query.one()
+ query = query.where(models.Conductor.online == online)
+ with _session_for_read() as session:
+ res = session.execute(query).one()[0]
+ return res
except NoResultFound:
raise exception.ConductorNotFound(conductor=hostname)
@oslo_db_api.retry_on_deadlock
def unregister_conductor(self, hostname):
- with _session_for_write():
- query = (model_query(models.Conductor)
- .filter_by(hostname=hostname, online=True))
- count = query.update({'online': False})
+ with _session_for_write() as session:
+ query = sa.update(models.Conductor).where(
+ models.Conductor.hostname == hostname,
+ models.Conductor.online == True).values( # noqa
+ online=False)
+ count = session.execute(query).rowcount
if count == 0:
raise exception.ConductorNotFound(conductor=hostname)
@oslo_db_api.retry_on_deadlock
def touch_conductor(self, hostname):
with _session_for_write():
- query = (model_query(models.Conductor)
- .filter_by(hostname=hostname))
+ query = model_query(models.Conductor)
+ query = query.where(models.Conductor.hostname == hostname)
# since we're not changing any other field, manually set updated_at
# and since we're heartbeating, make sure that online=True
count = query.update({'updated_at': timeutils.utcnow(),
@@ -1371,7 +1414,7 @@ class Connection(api.Connection):
def list_conductor_hardware_interfaces(self, conductor_id):
query = (model_query(models.ConductorHardwareInterfaces)
- .filter_by(conductor_id=conductor_id))
+ .where(models.ConductorHardwareInterfaces.conductor_id == conductor_id)) # noqa
return query.all()
def list_hardware_type_interfaces(self, hardware_types):
@@ -1422,7 +1465,8 @@ class Connection(api.Connection):
raise exception.NodeNotFound(node=node_id)
def _check_node_exists(self, node_id):
- if not model_query(models.Node).filter_by(id=node_id).scalar():
+ if not model_query(models.Node).where(
+ models.Node.id == node_id).scalar():
raise exception.NodeNotFound(node=node_id)
@oslo_db_api.retry_on_deadlock
@@ -1483,14 +1527,17 @@ class Connection(api.Connection):
return model_query(q.exists()).scalar()
def get_node_by_port_addresses(self, addresses):
- q = _get_node_query_with_all_for_single_node()
+ q = _get_node_select()
q = q.distinct().join(models.Port)
q = q.filter(models.Port.address.in_(addresses))
try:
# FIXME(TheJulia): This needs to be updated to be
# an explicit query to identify the node for SQLAlchemy.
- return q.one()
+ with _session_for_read() as session:
+ # Always return the first element, since we always
+ # get a tuple from sqlalchemy.
+ return session.execute(q).one()[0]
except NoResultFound:
raise exception.NodeNotFound(
_('Node with port addresses %s was not found')
@@ -1526,7 +1573,8 @@ class Connection(api.Connection):
def get_volume_connectors_by_node_id(self, node_id, limit=None,
marker=None, sort_key=None,
sort_dir=None, project=None):
- query = model_query(models.VolumeConnector).filter_by(node_id=node_id)
+ query = model_query(models.VolumeConnector).where(
+ models.VolumeConnector.node_id == node_id)
if project:
add_volume_conn_filter_by_node_project(query, project)
return _paginate_query(models.VolumeConnector, limit, marker,
@@ -1594,7 +1642,8 @@ class Connection(api.Connection):
sort_key, sort_dir, query)
def get_volume_target_by_id(self, db_id):
- query = model_query(models.VolumeTarget).filter_by(id=db_id)
+ query = model_query(models.VolumeTarget).where(
+ models.VolumeTarget.id == db_id)
try:
return query.one()
except NoResultFound:
@@ -1619,7 +1668,8 @@ class Connection(api.Connection):
def get_volume_targets_by_volume_id(self, volume_id, limit=None,
marker=None, sort_key=None,
sort_dir=None, project=None):
- query = model_query(models.VolumeTarget).filter_by(volume_id=volume_id)
+ query = model_query(models.VolumeTarget).where(
+ models.VolumeTarget.volume_id == volume_id)
if project:
query = add_volume_target_filter_by_node_project(query, project)
return _paginate_query(models.VolumeTarget, limit, marker, sort_key,
@@ -2330,30 +2380,29 @@ class Connection(api.Connection):
# this does not work with PostgreSQL.
query = model_query(models.DeployTemplate)
query = add_identity_filter(query, template_id)
- try:
- ref = query.with_for_update().one()
- except NoResultFound:
- raise exception.DeployTemplateNotFound(
- template=template_id)
-
+ ref = query.with_for_update().one()
# First, update non-step columns.
steps = values.pop('steps', None)
ref.update(values)
-
# If necessary, update steps.
if steps is not None:
self._update_deploy_template_steps(session, ref.id, steps)
+ session.flush()
+ with _session_for_read() as session:
# Return the updated template joined with all relevant fields.
- query = _get_deploy_template_query_with_steps()
+ query = _get_deploy_template_select_with_steps()
query = add_identity_filter(query, template_id)
- # FIXME(TheJulia): This needs to be fixed for SQLAlchemy 2.0.
- return query.one()
+ return session.execute(query).one()[0]
except db_exc.DBDuplicateEntry as e:
if 'name' in e.columns:
raise exception.DeployTemplateDuplicateName(
name=values['name'])
raise
+ except NoResultFound:
+ # TODO(TheJulia): What would unified core raise?!?
+ raise exception.DeployTemplateNotFound(
+ template=template_id)
@oslo_db_api.retry_on_deadlock
def destroy_deploy_template(self, template_id):
@@ -2367,22 +2416,26 @@ class Connection(api.Connection):
def _get_deploy_template(self, field, value):
"""Helper method for retrieving a deploy template."""
- query = (_get_deploy_template_query_with_steps()
- .filter_by(**{field: value}))
+ query = (_get_deploy_template_select_with_steps()
+ .where(field == value))
try:
# FIXME(TheJulia): This needs to be fixed for SQLAlchemy 2.0
- return query.one()
+ with _session_for_read() as session:
+ return session.execute(query).one()[0]
except NoResultFound:
raise exception.DeployTemplateNotFound(template=value)
def get_deploy_template_by_id(self, template_id):
- return self._get_deploy_template('id', template_id)
+ return self._get_deploy_template(models.DeployTemplate.id,
+ template_id)
def get_deploy_template_by_uuid(self, template_uuid):
- return self._get_deploy_template('uuid', template_uuid)
+ return self._get_deploy_template(models.DeployTemplate.uuid,
+ template_uuid)
def get_deploy_template_by_name(self, template_name):
- return self._get_deploy_template('name', template_name)
+ return self._get_deploy_template(models.DeployTemplate.name,
+ template_name)
def get_deploy_template_list(self, limit=None, marker=None,
sort_key=None, sort_dir=None):
@@ -2391,9 +2444,14 @@ class Connection(api.Connection):
sort_key, sort_dir, query)
def get_deploy_template_list_by_names(self, names):
- query = (_get_deploy_template_query_with_steps()
- .filter(models.DeployTemplate.name.in_(names)))
- return query.all()
+ query = _get_deploy_template_select_with_steps()
+ with _session_for_read() as session:
+ res = session.execute(
+ query.where(
+ models.DeployTemplate.name.in_(names)
+ )
+ ).all()
+ return [r[0] for r in res]
@oslo_db_api.retry_on_deadlock
def create_node_history(self, values):
@@ -2440,7 +2498,7 @@ class Connection(api.Connection):
def get_node_history_by_node_id(self, node_id, limit=None, marker=None,
sort_key=None, sort_dir=None):
query = model_query(models.NodeHistory)
- query = query.filter_by(node_id=node_id)
+ query = query.where(models.NodeHistory.node_id == node_id)
return _paginate_query(models.NodeHistory, limit, marker,
sort_key, sort_dir, query)
@@ -2507,6 +2565,9 @@ class Connection(api.Connection):
# Uses input entry list, selects entries matching those ids
# then deletes them and does not synchronize the session so
# sqlalchemy doesn't do extra un-necessary work.
+ # NOTE(TheJulia): This is "legacy" syntax, but it is still
+ # valid and under the hood SQLAlchemy rewrites the form into
+ # a delete syntax.
session.query(
models.NodeHistory
).filter(
@@ -2525,13 +2586,12 @@ class Connection(api.Connection):
# literally have the DB do *all* of the world, so no
# client side ops occur. The column is also indexed,
# which means this will be an index based response.
- # TODO(TheJulia): This might need to be revised for
- # SQLAlchemy 2.0 as it should be a scaler select and count
- # instead.
- return session.query(
- models.Node.provision_state
- ).filter(
- or_(
- models.Node.provision_state == v for v in state
+ return session.scalar(
+ sa.select(
+ sa.func.count(models.Node.id)
+ ).filter(
+ or_(
+ models.Node.provision_state == v for v in state
+ )
)
- ).count()
+ )
diff --git a/ironic/db/sqlalchemy/models.py b/ironic/db/sqlalchemy/models.py
index 3631d83a9..fad29f095 100644
--- a/ironic/db/sqlalchemy/models.py
+++ b/ironic/db/sqlalchemy/models.py
@@ -414,6 +414,10 @@ class DeployTemplate(Base):
uuid = Column(String(36))
name = Column(String(255), nullable=False)
extra = Column(db_types.JsonEncodedDict)
+ steps: orm.Mapped[List['DeployTemplateStep']] = orm.relationship( # noqa
+ "DeployTemplateStep",
+ back_populates="deploy_template",
+ lazy="selectin")
class DeployTemplateStep(Base):
@@ -434,7 +438,6 @@ class DeployTemplateStep(Base):
priority = Column(Integer, nullable=False)
deploy_template = orm.relationship(
"DeployTemplate",
- backref='steps',
primaryjoin=(
'and_(DeployTemplateStep.deploy_template_id == '
'DeployTemplate.id)'),