diff options
Diffstat (limited to 'nova/db/sqlalchemy/api.py')
-rw-r--r-- | nova/db/sqlalchemy/api.py | 57 |
1 files changed, 42 insertions, 15 deletions
diff --git a/nova/db/sqlalchemy/api.py b/nova/db/sqlalchemy/api.py index 15c583f458..278b309dfe 100644 --- a/nova/db/sqlalchemy/api.py +++ b/nova/db/sqlalchemy/api.py @@ -36,6 +36,7 @@ from sqlalchemy.exc import NoSuchTableError from sqlalchemy import Integer from sqlalchemy import MetaData from sqlalchemy import or_ +from sqlalchemy.orm import contains_eager from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload_all from sqlalchemy.schema import Table @@ -1690,7 +1691,7 @@ def instance_get_all_by_filters(context, filters, sort_key, sort_dir, # For other filters that don't match this, we will do regexp matching exact_match_filter_names = ['project_id', 'user_id', 'image_ref', 'vm_state', 'instance_type_id', 'uuid', - 'metadata'] + 'metadata', 'task_state'] # Filter the query query_prefix = exact_filter(query_prefix, models.Instance, @@ -3231,6 +3232,8 @@ def security_group_rule_get_by_security_group(context, security_group_id, filter_by(parent_group_id=security_group_id).\ options(joinedload_all('grantee_group.instances.' 'system_metadata')).\ + options(joinedload('grantee_group.instances.' + 'info_cache')).\ all() @@ -3556,7 +3559,7 @@ def instance_type_create(context, values): pass try: instance_type_get_by_flavor_id(context, values['flavorid'], - session) + read_deleted='no', session=session) raise exception.InstanceTypeIdExists(flavor_id=values['flavorid']) except exception.FlavorNotFound: pass @@ -3598,9 +3601,16 @@ def _dict_with_extra_specs(inst_type_query): def _instance_type_get_query(context, session=None, read_deleted=None): - return model_query(context, models.InstanceTypes, session=session, + query = model_query(context, models.InstanceTypes, session=session, read_deleted=read_deleted).\ - options(joinedload('extra_specs')) + options(joinedload('extra_specs')) + if not context.is_admin: + the_filter = [models.InstanceTypes.is_public == True] + the_filter.extend([ + models.InstanceTypes.projects.any(project_id=context.project_id) + ]) + query = query.filter(or_(*the_filter)) + return query @require_context @@ -3675,9 +3685,11 @@ def instance_type_get_by_name(context, name, session=None): @require_context -def instance_type_get_by_flavor_id(context, flavor_id, session=None): +def instance_type_get_by_flavor_id(context, flavor_id, read_deleted, + session=None): """Returns a dict describing specific flavor_id.""" - result = _instance_type_get_query(context, session=session).\ + result = _instance_type_get_query(context, read_deleted=read_deleted, + session=session).\ filter_by(flavorid=flavor_id).\ first() @@ -3727,7 +3739,7 @@ def instance_type_access_add(context, flavor_id, project_id): session = get_session() with session.begin(): instance_type_ref = instance_type_get_by_flavor_id(context, flavor_id, - session=session) + read_deleted='no', session=session) instance_type_id = instance_type_ref['id'] access_ref = _instance_type_access_query(context, session=session).\ filter_by(instance_type_id=instance_type_id).\ @@ -3750,7 +3762,7 @@ def instance_type_access_remove(context, flavor_id, project_id): session = get_session() with session.begin(): instance_type_ref = instance_type_get_by_flavor_id(context, flavor_id, - session=session) + read_deleted='no', session=session) instance_type_id = instance_type_ref['id'] count = _instance_type_access_query(context, session=session).\ filter_by(instance_type_id=instance_type_id).\ @@ -4346,8 +4358,16 @@ def aggregate_get(context, aggregate_id): @require_admin_context def aggregate_get_by_host(context, host, key=None): - query = _aggregate_get_query(context, models.Aggregate, - models.AggregateHost.host, host) + """Return rows that match host (mandatory) and metadata key (optional). + + :param host matches host, and is required. + :param key Matches metadata key, if not None. + """ + query = model_query(context, models.Aggregate) + query = query.options(joinedload('_hosts')) + query = query.options(joinedload('_metadata')) + query = query.join('_hosts') + query = query.filter(models.AggregateHost.host == host) if key: query = query.join("_metadata").filter( @@ -4357,13 +4377,16 @@ def aggregate_get_by_host(context, host, key=None): @require_admin_context def aggregate_metadata_get_by_host(context, host, key=None): - query = model_query(context, models.Aggregate).join( - "_hosts").filter(models.AggregateHost.host == host).join( - "_metadata") + query = model_query(context, models.Aggregate) + query = query.join("_hosts") + query = query.join("_metadata") + query = query.filter(models.AggregateHost.host == host) + query = query.options(contains_eager("_metadata")) if key: query = query.filter(models.AggregateMetadata.key == key) rows = query.all() + metadata = collections.defaultdict(set) for agg in rows: for kv in agg._metadata: @@ -4373,9 +4396,13 @@ def aggregate_metadata_get_by_host(context, host, key=None): @require_admin_context def aggregate_host_get_by_metadata_key(context, key): - query = model_query(context, models.Aggregate).join( - "_metadata").filter(models.AggregateMetadata.key == key) + query = model_query(context, models.Aggregate) + query = query.join("_metadata") + query = query.filter(models.AggregateMetadata.key == key) + query = query.options(contains_eager("_metadata")) + query = query.options(joinedload("_hosts")) rows = query.all() + metadata = collections.defaultdict(set) for agg in rows: for agghost in agg._hosts: |