# orm/mapper.py # Copyright (C) 2005, 2006, 2007, 2008 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php """Defines the [sqlalchemy.orm.mapper#Mapper] class, the central configurational unit which associates a class with a database table. This is a semi-private module; the main configurational API of the ORM is available in [sqlalchemy.orm#]. """ import weakref from itertools import chain from sqlalchemy import sql, util, exceptions, logging from sqlalchemy.sql import expression, visitors, operators, util as sqlutil from sqlalchemy.orm import sync, attributes from sqlalchemy.orm.interfaces import MapperProperty, EXT_CONTINUE, PropComparator from sqlalchemy.orm.util import has_identity, _state_has_identity, _is_mapped_class, has_mapper, \ _state_mapper, class_mapper, object_mapper, _class_to_mapper,\ ExtensionCarrier, state_str, instance_str __all__ = ['Mapper', 'class_mapper', 'object_mapper', '_mapper_registry'] _mapper_registry = weakref.WeakKeyDictionary() _new_mappers = False _already_compiling = False # a list of MapperExtensions that will be installed in all mappers by default global_extensions = [] # a constant returned by _get_attr_by_column to indicate # this mapper is not handling an attribute for a particular # column NO_ATTRIBUTE = util.symbol('NO_ATTRIBUTE') # lock used to synchronize the "mapper compile" step _COMPILE_MUTEX = util.threading.RLock() # initialize these lazily ColumnProperty = None SynonymProperty = None ComparableProperty = None _expire_state = None class Mapper(object): """Define the correlation of class attributes to database table columns. Instances of this class should be constructed via the [sqlalchemy.orm#mapper()] function. 
""" def __init__(self, class_, local_table, properties = None, primary_key = None, non_primary = False, inherits = None, inherit_condition = None, inherit_foreign_keys = None, extension = None, order_by = False, allow_column_override = False, entity_name = None, always_refresh = False, version_id_col = None, polymorphic_on=None, _polymorphic_map=None, polymorphic_identity=None, polymorphic_fetch=None, concrete=False, select_table=None, with_polymorphic=None, allow_null_pks=False, batch=True, column_prefix=None, include_properties=None, exclude_properties=None, eager_defaults=False): """Construct a new mapper. Mappers are normally constructed via the [sqlalchemy.orm#mapper()] function. See for details. """ self.class_ = class_ self.entity_name = entity_name self.primary_key_argument = primary_key self.non_primary = non_primary self.order_by = order_by self.always_refresh = always_refresh self.version_id_col = version_id_col self.concrete = concrete self.single = False self.inherits = inherits self.local_table = local_table self.inherit_condition = inherit_condition self.inherit_foreign_keys = inherit_foreign_keys self.extension = extension self._init_properties = properties or {} self.allow_column_override = allow_column_override self.allow_null_pks = allow_null_pks self.delete_orphans = [] self.batch = batch self.eager_defaults = eager_defaults self.column_prefix = column_prefix self.polymorphic_on = polymorphic_on self._eager_loaders = util.Set() self._dependency_processors = [] self._clause_adapter = None self._requires_row_aliasing = False self.__inherits_equated_pairs = None if not issubclass(class_, object): raise exceptions.ArgumentError("Class '%s' is not a new-style class" % class_.__name__) self.select_table = select_table if select_table: if with_polymorphic: raise exceptions.ArgumentError("select_table can't be used with with_polymorphic (they define conflicting settings)") self.with_polymorphic = ('*', select_table) else: if with_polymorphic == '*': 
self.with_polymorphic = ('*', None) elif isinstance(with_polymorphic, (tuple, list)): if isinstance(with_polymorphic[0], (basestring, tuple, list)): self.with_polymorphic = with_polymorphic else: self.with_polymorphic = (with_polymorphic, None) elif with_polymorphic is not None: raise exceptions.ArgumentError("Invalid setting for with_polymorphic") else: self.with_polymorphic = None if isinstance(self.local_table, expression._SelectBaseMixin): util.warn("mapper %s creating an alias for the given selectable - use Class attributes for queries." % self) self.local_table = self.local_table.alias() if self.with_polymorphic and isinstance(self.with_polymorphic[1], expression._SelectBaseMixin): self.with_polymorphic[1] = self.with_polymorphic[1].alias() # our 'polymorphic identity', a string name that when located in a result set row # indicates this Mapper should be used to construct the object instance for that row. self.polymorphic_identity = polymorphic_identity if polymorphic_fetch not in (None, 'union', 'select', 'deferred'): raise exceptions.ArgumentError("Invalid option for 'polymorphic_fetch': '%s'" % polymorphic_fetch) if polymorphic_fetch is None: self.polymorphic_fetch = (self.with_polymorphic is None) and 'select' or 'union' else: self.polymorphic_fetch = polymorphic_fetch # a dictionary of 'polymorphic identity' names, associating those names with # Mappers that will be used to construct object instances upon a select operation. if _polymorphic_map is None: self.polymorphic_map = {} else: self.polymorphic_map = _polymorphic_map self.columns = self.c = util.OrderedProperties() self.include_properties = include_properties self.exclude_properties = exclude_properties # a set of all mappers which inherit from this one. 
        self._inheriting_mappers = util.Set()

        # False until __initialize_properties() has run.  compile(), the
        # `compiled` property and add_property() all consult it.
        self.__props_init = False

        # Logging levels are sampled once, at construction time.
        # NOTE(review): self.logger is not assigned in the visible part of
        # this module; presumably a class-level logger attached elsewhere.
        self.__should_log_info = logging.is_info_enabled(self.logger)
        self.__should_log_debug = logging.is_debug_enabled(self.logger)

        # Order matters:
        # - __compile_class() sees the caller's raw `extension` argument
        #   before __compile_extensions() replaces self.extension with an
        #   ExtensionCarrier.
        # - __compile_inheritance() sets mapped_table, which
        #   __compile_properties() reads.
        # - __compile_pks() reads the _columntoproperty map that
        #   __compile_properties() builds.
        self.__compile_class()
        self.__compile_inheritance()
        self.__compile_extensions()
        self.__compile_properties()
        self.__compile_pks()
        # Tell compile() that at least one mapper has properties that are
        # not initialized yet.
        global _new_mappers
        _new_mappers = True
        self.__log("constructed")

    def __log(self, msg):
        """Log ``msg`` at INFO level, prefixed with a description of this
        mapper, if INFO was enabled when the mapper was constructed."""
        if self.__should_log_info:
            self.logger.info("(" + self.class_.__name__ + "|" +
                             (self.entity_name is not None and "/%s" % self.entity_name or "") +
                             (self.local_table and self.local_table.description or str(self.local_table)) +
                             (self.non_primary and "|non-primary" or "") + ") " + msg)

    def __log_debug(self, msg):
        """Log ``msg`` at DEBUG level, prefixed with a description of this
        mapper, if DEBUG was enabled when the mapper was constructed."""
        if self.__should_log_debug:
            self.logger.debug("(" + self.class_.__name__ + "|" +
                              (self.entity_name is not None and "/%s" % self.entity_name or "") +
                              (self.local_table and self.local_table.description or str(self.local_table)) +
                              (self.non_primary and "|non-primary" or "") + ") " + msg)

    def _is_orphan(self, obj):
        """Return True if ``obj`` is an orphan.

        ``obj`` is an orphan when some mapper between this one and the root
        declares delete-orphan relationships (``(key, class)`` pairs in
        ``delete_orphans``) and ``obj`` has no parent through any of them.
        Returns False as soon as a parent is found, and also when no mapper
        in the chain declares delete-orphan relationships at all.
        """
        o = False
        for mapper in self.iterate_to_root():
            for (key,klass) in mapper.delete_orphans:
                if attributes.has_parent(klass, obj, key, optimistic=has_identity(obj)):
                    return False
            o = o or bool(mapper.delete_orphans)
        return o

    def get_property(self, key, resolve_synonyms=False, raiseerr=True):
        """return a MapperProperty associated with the given key.

        Compiles mappers first (see compile()).  When ``resolve_synonyms``
        is true, SynonymProperty objects are followed repeatedly through
        their ``name`` until a non-synonym is reached.  If nothing is found,
        raises ``exceptions.InvalidRequestError`` when ``raiseerr`` is true,
        and otherwise returns None.
        """

        self.compile()
        return self._get_property(key, resolve_synonyms=resolve_synonyms, raiseerr=raiseerr)

    def _get_property(self, key, resolve_synonyms=False, raiseerr=True):
        """private in-compilation version of get_property().

        Same lookup rules as get_property(), but never calls compile(), so
        it is safe to use while mappers are still being configured.
        """

        prop = self.__props.get(key, None)
        if resolve_synonyms:
            while isinstance(prop, SynonymProperty):
                prop = self.__props.get(prop.name, None)
        if prop is None and raiseerr:
            raise exceptions.InvalidRequestError("Mapper '%s' has no property '%s'" % (str(self), key))
        return prop

    def iterate_properties(self):
        # compile first, then iterate this mapper's property dictionary
        self.compile()
        return self.__props.itervalues()
    iterate_properties = property(iterate_properties, doc="returns an iterator of all MapperProperty objects.")

    def __adjust_wp_selectable(self, spec=None, selectable=False):
        """given a with_polymorphic() argument, resolve it against this mapper's with_polymorphic setting

        Returns ``(spec, selectable, isdefault)``.  A false ``spec``, or a
        ``selectable`` of False (the "not given" marker, distinct from
        None), falls back to the matching entry of ``self.with_polymorphic``.
        ``isdefault`` is True only when this mapper has a with_polymorphic
        setting and the caller supplied neither argument.  Callers pass it
        as the ``cache`` flag of the cached helpers.
        """

        isdefault = False
        if self.with_polymorphic:
            isdefault = not spec and selectable is False

            if not spec:
                spec = self.with_polymorphic[0]
            if selectable is False:
                selectable = self.with_polymorphic[1]

        return spec, selectable, isdefault

    def __mappers_from_spec(self, spec, selectable):
        """given a with_polymorphic() argument, return the set of mappers it represents.

        Trims the list of mappers to just those represented within the given selectable, if present.
        This helps some more legacy-ish mappings.

        ``'*'`` selects this mapper and every descendant (polymorphic_iterator()).
        Any other true ``spec`` is a single item or a list, each resolved
        through _class_to_mapper().  A false ``spec`` yields an empty list.
        """
        if spec == '*':
            mappers = list(self.polymorphic_iterator())
        elif spec:
            mappers = [_class_to_mapper(m) for m in util.to_list(spec)]
        else:
            mappers = []

        if selectable:
            tables = util.Set(sqlutil.find_tables(selectable))
            mappers = [m for m in mappers if m.local_table in tables]

        return mappers
    # Callers pass cache=isdefault.  Presumably only the default,
    # argument-less form is cached -- confirm against
    # util.conditional_cache_decorator.
    __mappers_from_spec = util.conditional_cache_decorator(__mappers_from_spec)

    def __selectable_from_mappers(self, mappers):
        """given a list of mappers (assumed to be within this mapper's inheritance hierarchy),
        construct an outerjoin amongst those mapper's mapped tables.
""" from_obj = self.mapped_table for m in mappers: if m is self: continue if m.concrete: raise exceptions.InvalidRequestError("'with_polymorphic()' requires 'selectable' argument when concrete-inheriting mappers are used.") elif not m.single: from_obj = from_obj.outerjoin(m.local_table, m.inherit_condition) return from_obj __selectable_from_mappers = util.conditional_cache_decorator(__selectable_from_mappers) def _with_polymorphic_mappers(self, spec=None, selectable=False): spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable) return self.__mappers_from_spec(spec, selectable, cache=isdefault) def _with_polymorphic_selectable(self, spec=None, selectable=False): spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable) if selectable: return selectable else: return self.__selectable_from_mappers(self.__mappers_from_spec(spec, selectable, cache=isdefault), cache=isdefault) def _with_polymorphic_args(self, spec=None, selectable=False): spec, selectable, isdefault = self.__adjust_wp_selectable(spec, selectable) mappers = self.__mappers_from_spec(spec, selectable, cache=isdefault) if selectable: return mappers, selectable else: return mappers, self.__selectable_from_mappers(mappers, cache=isdefault) def _iterate_polymorphic_properties(self, spec=None, selectable=False): return iter(util.OrderedSet( chain(*[list(mapper.iterate_properties) for mapper in [self] + self._with_polymorphic_mappers(spec, selectable)]) )) def properties(self): raise NotImplementedError("Public collection of MapperProperty objects is provided by the get_property() and iterate_properties accessors.") properties = property(properties) def compiled(self): """return True if this mapper is compiled""" return self.__props_init compiled = property(compiled) def dispose(self): # disaable any attribute-based compilation self.__props_init = True try: del self.class_.c except AttributeError: pass if not self.non_primary and self.entity_name in self._class_state.mappers: 
del self._class_state.mappers[self.entity_name] if not self._class_state.mappers: attributes.unregister_class(self.class_) def compile(self): """Compile this mapper and all other non-compiled mappers. This method checks the local compiled status as well as for any new mappers that have been defined, and is safe to call repeatedly. """ global _new_mappers if self.__props_init and not _new_mappers: return self _COMPILE_MUTEX.acquire() try: global _already_compiling if _already_compiling: self.__initialize_properties() return _already_compiling = True try: # double-check inside mutex if self.__props_init and not _new_mappers: return self # initialize properties on all mappers for mapper in list(_mapper_registry): if not mapper.__props_init: mapper.__initialize_properties() _new_mappers = False return self finally: _already_compiling = False finally: _COMPILE_MUTEX.release() def __initialize_properties(self): """Call the ``init()`` method on all ``MapperProperties`` attached to this mapper. This is a deferred configuration step which is intended to execute once all mappers have been constructed. """ self.__log("__initialize_properties() started") l = [(key, prop) for key, prop in self.__props.iteritems()] for key, prop in l: self.__log("initialize prop " + key) if getattr(prop, 'key', None) is None: prop.init(key, self) self.__log("__initialize_properties() complete") self.__props_init = True def __compile_extensions(self): """Go through the global_extensions list as well as the list of ``MapperExtensions`` specified for this ``Mapper`` and creates a linked list of those extensions. 
        """
        # ordered and duplicate-free
        extlist = util.OrderedSet()

        extension = self.extension
        if extension:
            for ext_obj in util.to_list(extension):
                # local MapperExtensions have already instrumented the class
                extlist.add(ext_obj)

        if self.inherits:
            # The parent's `extension` is its ExtensionCarrier, built when
            # the parent mapper was constructed.  Newly seen extensions also
            # instrument this subclass.
            for ext in self.inherits.extension:
                if ext not in extlist:
                    extlist.add(ext)
                    ext.instrument_class(self, self.class_)
        else:
            # Only base mappers pick up global_extensions.  Entries given as
            # classes are instantiated per mapper.
            for ext in global_extensions:
                if isinstance(ext, type):
                    ext = ext()
                if ext not in extlist:
                    extlist.add(ext)
                    ext.instrument_class(self, self.class_)

        self.extension = ExtensionCarrier()
        for ext in extlist:
            self.extension.append(ext)

    def __compile_inheritance(self):
        """Configure settings related to inheriting and/or inherited mappers being present.

        Resolves ``inherits`` (a class or a Mapper) to a Mapper and determines
        mapped_table (single-table, joined-table or concrete inheritance).
        Also sets up polymorphic_on and the polymorphic identity, the
        identity class, base_mapper and the shared _all_tables set.

        Raises ``exceptions.ArgumentError`` for an unrelated parent class, a
        primary/non-primary mismatch with the parent, a polymorphic_identity
        with no polymorphic_on anywhere in the hierarchy, or no mapped_table.
        """

        if self.inherits:
            if isinstance(self.inherits, type):
                self.inherits = class_mapper(self.inherits, compile=False)
            else:
                # no-op: self.inherits is already a Mapper
                self.inherits = self.inherits
            if not issubclass(self.class_, self.inherits.class_):
                raise exceptions.ArgumentError("Class '%s' does not inherit from '%s'" % (self.class_.__name__, self.inherits.class_.__name__))
            if self.non_primary != self.inherits.non_primary:
                np = not self.non_primary and "primary" or "non-primary"
                raise exceptions.ArgumentError("Inheritance of %s mapper for class '%s' is only allowed from a %s mapper" % (np, self.class_.__name__, np))
            # inherit_condition is optional.
            if self.local_table is None:
                # No table of its own: single-table inheritance on the
                # parent's table.
                self.local_table = self.inherits.local_table
                self.mapped_table = self.inherits.mapped_table
                self.single = True
            elif not self.local_table is self.inherits.local_table:
                if self.concrete:
                    # Concrete inheritance maps only its own table.  Every
                    # polymorphic ancestor is flagged as needing row aliasing.
                    self.mapped_table = self.local_table
                    for mapper in self.iterate_to_root():
                        if mapper.polymorphic_on:
                            mapper._requires_row_aliasing = True
                else:
                    # joined-table inheritance
                    if self.inherit_condition is None:
                        # figure out inherit condition from our table to the immediate table
                        # of the inherited mapper, not its full table which could pull in other
                        # stuff we dont want (allows test/inheritance.InheritTest4 to pass)
                        self.inherit_condition = sqlutil.join_condition(self.inherits.local_table, self.local_table)
                    self.mapped_table = sql.join(self.inherits.mapped_table, self.local_table, self.inherit_condition)

                    fks = util.to_set(self.inherit_foreign_keys)
                    self.__inherits_equated_pairs = sqlutil.criterion_as_pairs(self.mapped_table.onclause, consider_as_foreign_keys=fks)
            else:
                self.mapped_table = self.local_table

            if self.polymorphic_identity is not None:
                self.inherits.polymorphic_map[self.polymorphic_identity] = self
                if self.polymorphic_on is None:
                    for mapper in self.iterate_to_root():
                        # try to set up polymorphic on using correesponding_column(); else leave
                        # as None
                        if mapper.polymorphic_on:
                            self.polymorphic_on = self.mapped_table.corresponding_column(mapper.polymorphic_on)
                            break
                    else:
                        # reached only if no mapper up to the root has polymorphic_on
                        # TODO: this exception not covered
                        raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))

            if self.polymorphic_identity and not self.concrete:
                self._identity_class = self.inherits._identity_class
            else:
                self._identity_class = self.class_

            if self.version_id_col is None:
                self.version_id_col = self.inherits.version_id_col

            # A new inheriting mapper changes the equivalent-column map of
            # every ancestor, so drop their cached values.
            for mapper in self.iterate_to_root():
                util.reset_cached(mapper, '_equivalent_columns')

            if self.order_by is False:
                self.order_by = self.inherits.order_by

            # These are shared (same objects) with the parent mapper.
            self.polymorphic_map = self.inherits.polymorphic_map
            self.batch = self.inherits.batch
            self.inherits._inheriting_mappers.add(self)
            self.base_mapper = self.inherits.base_mapper
            self._all_tables = self.inherits._all_tables
        else:
            self._all_tables = util.Set()
            self.base_mapper = self
            self.mapped_table = self.local_table
            if self.polymorphic_identity:
                if self.polymorphic_on is None:
                    raise exceptions.ArgumentError("Mapper '%s' specifies a polymorphic_identity of '%s', but no mapper in it's hierarchy specifies the 'polymorphic_on' column argument" % (str(self), self.polymorphic_identity))
                self.polymorphic_map[self.polymorphic_identity] = self
            self._identity_class = self.class_

        if self.mapped_table is None:
            raise exceptions.ArgumentError("Mapper '%s' does not have a mapped_table specified. (Are you using the return value of table.create()? It no longer has a return value.)" % str(self))

    def __compile_pks(self):
        """Find the Table objects behind mapped_table and record, per table,
        the mapped primary-key columns (_pks_by_table) and mapped columns
        (_cols_by_table).  Then determine self.primary_key.

        Raises ``exceptions.InvalidRequestError`` if no tables are found and
        ``exceptions.ArgumentError`` if no primary key can be assembled.
        """

        self.tables = sqlutil.find_tables(self.mapped_table)

        if not self.tables:
            raise exceptions.InvalidRequestError("Could not find any Table objects in mapped table '%s'" % str(self.mapped_table))

        self._pks_by_table = {}
        self._cols_by_table = {}

        # every column (including proxied ones) mapped by some property
        all_cols = util.Set(chain(*[col.proxy_set for col in self._columntoproperty]))
        pk_cols = util.Set([c for c in all_cols if c.primary_key])

        # identify primary key columns which are also mapped by this mapper.
        for t in util.Set(self.tables + [self.mapped_table]):
            # For an inheriting mapper this set is the parent's own
            # _all_tables object, so the whole hierarchy accumulates here.
            self._all_tables.add(t)
            if t.primary_key and pk_cols.issuperset(t.primary_key):
                # ordering is important since it determines the ordering of mapper.primary_key (and therefore query.get())
                self._pks_by_table[t] = util.OrderedSet(t.primary_key).intersection(pk_cols)
            self._cols_by_table[t] = util.OrderedSet(t.c).intersection(all_cols)

        # if explicit PK argument sent, add those columns to the primary key mappings
        if self.primary_key_argument:
            for k in self.primary_key_argument:
                if k.table not in self._pks_by_table:
                    self._pks_by_table[k.table] = util.OrderedSet()
                self._pks_by_table[k.table].add(k)

        if self.mapped_table not in self._pks_by_table or len(self._pks_by_table[self.mapped_table]) == 0:
            raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))

        if self.inherits and not self.concrete and not self.primary_key_argument:
            # if inheriting, the "primary key" for this mapper is that of the inheriting (unless concrete or explicit)
            self.primary_key = self.inherits.primary_key
        else:
            # determine primary key from argument or mapped_table pks - reduce to the minimal set of columns
            if self.primary_key_argument:
                primary_key = sqlutil.reduce_columns([self.mapped_table.corresponding_column(c) for c in self.primary_key_argument])
            else:
                primary_key = sqlutil.reduce_columns(self._pks_by_table[self.mapped_table])

            if len(primary_key) == 0:
                raise exceptions.ArgumentError("Mapper %s could not assemble any primary key columns for mapped table '%s'" % (self, self.mapped_table.description))

            self.primary_key = primary_key
            self.__log("Identified primary key columns: " + str(primary_key))

    def _get_clause(self):
        """create a "get clause" based on the primary key.  this is used
        by query.get() and many-to-one lazyloads to load this item
        by primary key.

        Returns a 2-tuple ``(clause, params)``.  ``clause`` ANDs together
        ``pk_column == bindparam`` for every primary-key column, and
        ``params`` maps each primary-key column to its bindparam.  The value
        is computed once and then cached via util.cache_decorator.
        """
        params = [(primary_key, sql.bindparam(None, type_=primary_key.type)) for primary_key in self.primary_key]
        return sql.and_(*[k==v for (k, v) in params]), dict(params)
    _get_clause = property(util.cache_decorator(_get_clause))

    def _equivalent_columns(self):
        """Create a map of all *equivalent* columns, based on
        the determination of column pairs that are equated to
        one another either by an established foreign key relationship
        or by a joined-table inheritance join.

        This is used to determine the minimal set of primary key
        columns for the mapper, as well as when relating
        columns to those of a polymorphic selectable (i.e. a UNION of
        several mapped tables), as that selectable usually only contains
        one column in its columns clause out of a group of several which
        are equated to each other.

        The resulting structure is a dictionary of columns mapped
        to sets of equivalent columns, i.e.

        {
            tablea.col1:
                set([tableb.col1, tablec.col1]),
            tablea.col2:
                set([tabled.col2])
        }

        The result is cached; __compile_inheritance resets the cache on
        every ancestor when a new inheriting mapper is added.
        """

        result = {}
        # record both directions of every `a == b` found in an inherit condition
        def visit_binary(binary):
            if binary.operator == operators.eq:
                if binary.left in result:
                    result[binary.left].add(binary.right)
                else:
                    result[binary.left] = util.Set([binary.right])
                if binary.right in result:
                    result[binary.right].add(binary.left)
                else:
                    result[binary.right] = util.Set([binary.left])
        for mapper in self.base_mapper.polymorphic_iterator():
            if mapper.inherit_condition:
                visitors.traverse(mapper.inherit_condition, visit_binary=visit_binary)

        # TODO: matching of cols to foreign keys might better be generalized
        # into general column translation (i.e.
        # corresponding_column)

        # recursively descend into the foreign key collection of the given column
        # and assemble each FK-related col as an "equivalent" for the given column
        # (`recursive` holds the columns already visited, guarding against FK cycles)
        def equivs(col, recursive, equiv):
            if col in recursive:
                return
            recursive.add(col)
            for fk in col.foreign_keys:
                if fk.column not in result:
                    result[fk.column] = util.Set()
                result[fk.column].add(equiv)
                equivs(fk.column, recursive, col)

        for column in (self.primary_key_argument or self._pks_by_table[self.mapped_table]):
            for col in column.proxy_set:
                if not col.foreign_keys:
                    # a primary-key column with no FKs is equivalent to itself
                    if col not in result:
                        result[col] = util.Set()
                    result[col].add(col)
                else:
                    equivs(col, util.Set(), col)

        return result
    _equivalent_columns = property(util.cache_decorator(_equivalent_columns))

    class _CompileOnAttr(PropComparator):
        """A placeholder descriptor which triggers compilation on access.

        Set on the mapped class for each property by _compile_property().
        Accessing any non-dunder attribute calls class_mapper(cls), which
        presumably compiles the mapper and replaces this placeholder.  The
        lookup is then forwarded to the real class attribute.
        """

        def __init__(self, class_, key):
            self.class_ = class_
            self.key = key
            # Whatever the class defined under `key` before this placeholder
            # was installed.  _compile_property() recovers it as the
            # descriptor of synonym/comparable properties.
            self.existing_prop = getattr(class_, key, None)

        def __getattribute__(self, key):
            # object.__getattribute__ avoids recursing into this method
            cls = object.__getattribute__(self, 'class_')
            clskey = object.__getattribute__(self, 'key')

            if key.startswith('__'):
                return object.__getattribute__(self, key)

            class_mapper(cls)

            if cls.__dict__.get(clskey) is self:
                # FIXME: there should not be any scenarios where
                # a mapper compile leaves this CompileOnAttr in
                # place.
                util.warn(
                    ("Attribute '%s' on class '%s' was not replaced during "
                     "mapper compilation operation") % (clskey, cls.__name__))
                # clean us up explicitly
                delattr(cls, clskey)

            return getattr(getattr(cls, clskey), key)

    def __compile_properties(self):
        """Build this mapper's property dictionary and column-to-property map.

        Sources, in order: explicit ``properties``, properties inherited from
        the parent mapper, and then every column of mapped_table not already
        mapped (filtered by include/exclude_properties and prefixed with
        column_prefix).  The polymorphic_on column is added last if still
        unmapped.
        """

        # object attribute names mapped to MapperProperty objects
        self.__props = util.OrderedDict()

        # table columns mapped to lists of MapperProperty objects
        # using a list allows a single column to be defined as
        # populating multiple object attributes
        # NOTE(review): despite the comment above, _compile_property() stores
        # a single MapperProperty per column (self._columntoproperty[col] = prop).
        self._columntoproperty = {}

        # load custom properties
        if self._init_properties:
            for key, prop in self._init_properties.iteritems():
                self._compile_property(key, prop, False)

        # pull properties from the inherited mapper if any.
        if self.inherits:
            for key, prop in self.inherits.__props.iteritems():
                if key not in self.__props:
                    self._adapt_inherited_property(key, prop)

        # create properties for each column in the mapped table,
        # for those columns which don't already map to a property
        for column in self.mapped_table.columns:
            if column in self._columntoproperty:
                continue

            # include/exclude filters match the unprefixed column key
            if (self.include_properties is not None and
                column.key not in self.include_properties):
                self.__log("not including property %s" % (column.key))
                continue

            if (self.exclude_properties is not None and
                column.key in self.exclude_properties):
                self.__log("excluding property %s" % (column.key))
                continue

            column_key = (self.column_prefix or '') + column.key

            self._compile_property(column_key, column, init=False, setparent=True)

        # do a special check for the "discriminiator" column, as it may only be present
        # in the 'with_polymorphic' selectable but we need it for the base mapper
        if self.polymorphic_on and self.polymorphic_on not in self._columntoproperty:
            col = self.mapped_table.corresponding_column(self.polymorphic_on) or self.polymorphic_on
            self._compile_property(col.key, ColumnProperty(col), init=False, setparent=True)

    def _adapt_inherited_property(self, key, prop):
        """Add a property inherited from the parent mapper without re-parenting
        it (setparent=False).  Concrete mappers skip inherited properties."""
        if not self.concrete:
            self._compile_property(key, prop, init=False,
                                   setparent=False)
        # TODO: concrete properties dont adapt at all right now....will require copies of relations() etc.

    def _compile_property(self, key, prop, init=True, setparent=True):
        """Install ``prop`` on this mapper under the attribute name ``key``.

        ``prop`` may be a MapperProperty, a Column, or a list of Columns.
        Columns are turned into a new ColumnProperty, or appended to an
        existing ColumnProperty under the same key.  With ``setparent`` the
        property is parented to this mapper, and a primary mapper also puts
        a _CompileOnAttr placeholder on the class.  With ``init`` the
        property's init() runs immediately.  The property is finally
        propagated to every inheriting mapper.

        Raises ``exceptions.ArgumentError`` for non-column values, columns
        missing from mapped_table, a column clashing with an existing
        non-column property (unless allow_column_override), or a synonym
        with map_column whose column does not exist.
        """
        self.__log("_compile_property(%s, %s)" % (key, prop.__class__.__name__))

        if not isinstance(prop, MapperProperty):
            # we were passed a Column or a list of Columns; generate a ColumnProperty
            columns = util.to_list(prop)
            column = columns[0]
            if not expression.is_column(column):
                raise exceptions.ArgumentError("%s=%r is not an instance of MapperProperty or Column" % (key, prop))

            prop = self.__props.get(key, None)

            if isinstance(prop, ColumnProperty):
                # TODO: the "property already exists" case is still not well defined here.
                # assuming single-column, etc.

                if prop.parent is not self:
                    # existing ColumnProperty whose parent is another mapper
                    # (e.g. one added by _adapt_inherited_property(), which
                    # does not re-parent); make a copy and append our column to it
                    prop = prop.copy()
                prop.columns.append(column)
                self.__log("appending to existing ColumnProperty %s" % (key))
            elif prop is None:
                mapped_column = []
                for c in columns:
                    mc = self.mapped_table.corresponding_column(c)
                    if not mc:
                        raise exceptions.ArgumentError("Column '%s' is not represented in mapper's table. Use the `column_property()` function to force this column to be mapped as a read-only attribute." % str(c))
                    mapped_column.append(mc)
                prop = ColumnProperty(*mapped_column)
            else:
                if not self.allow_column_override:
                    raise exceptions.ArgumentError("WARNING: column '%s' not being added due to property '%s'. Specify 'allow_column_override=True' to mapper() to ignore this condition." % (column.key, repr(prop)))
                else:
                    # override allowed: silently keep the existing property
                    return

        if isinstance(prop, ColumnProperty):
            col = self.mapped_table.corresponding_column(prop.columns[0])
            # col might not be present! the selectable given to the mapper need not include "deferred"
            # columns (included in zblog tests)
            if col is None:
                col = prop.columns[0]
            else:
                # if column is coming in after _cols_by_table was initialized, ensure the col is in the
                # right set
                if hasattr(self, '_cols_by_table') and col.table in self._cols_by_table and col not in self._cols_by_table[col.table]:
                    self._cols_by_table[col.table].add(col)

            self.columns[key] = col
            # The inner loop deliberately rebinds `col`: every column in each
            # column's proxy_set maps to this property.
            for col in prop.columns:
                for col in col.proxy_set:
                    self._columntoproperty[col] = prop

        elif isinstance(prop, SynonymProperty) and setparent:
            if prop.descriptor is None:
                prop.descriptor = getattr(self.class_, key, None)
                if isinstance(prop.descriptor, Mapper._CompileOnAttr):
                    # use the attribute the placeholder replaced, not the placeholder
                    prop.descriptor = object.__getattribute__(prop.descriptor, 'existing_prop')
            if prop.map_column:
                if not key in self.mapped_table.c:
                    raise exceptions.ArgumentError("Can't compile synonym '%s': no column on table '%s' named '%s'" % (prop.name, self.mapped_table.description, key))
                self._compile_property(prop.name, ColumnProperty(self.mapped_table.c[key]), init=init, setparent=setparent)
        elif isinstance(prop, ComparableProperty) and setparent:
            # refactor me
            if prop.descriptor is None:
                prop.descriptor = getattr(self.class_, key, None)
                if isinstance(prop.descriptor, Mapper._CompileOnAttr):
                    prop.descriptor = object.__getattribute__(prop.descriptor, 'existing_prop')
        self.__props[key] = prop

        if setparent:
            prop.set_parent(self)

            if not self.non_primary:
                setattr(self.class_, key, Mapper._CompileOnAttr(self.class_, key))

        if init:
            prop.init(key, self)

        for mapper in self._inheriting_mappers:
            mapper._adapt_inherited_property(key, prop)

    def __compile_class(self):
        """If this mapper is to be a primary mapper (i.e. the
        non_primary flag is not set), associate this Mapper with the
        given class_ and entity name.

        Subsequent calls to ``class_mapper()`` for the class_/entity
        name combination will return this mapper.
Also decorate the `__init__` method on the mapped class to include optional auto-session attachment logic. """ if self.non_primary: if not hasattr(self.class_, '_class_state'): raise exceptions.InvalidRequestError("Class %s has no primary mapper configured. Configure a primary mapper first before setting up a non primary Mapper.") self._class_state = self.class_._class_state _mapper_registry[self] = True return if not self.non_primary and '_class_state' in self.class_.__dict__ and (self.entity_name in self.class_._class_state.mappers): raise exceptions.ArgumentError("Class '%s' already has a primary mapper defined with entity name '%s'. Use non_primary=True to create a non primary Mapper. clear_mappers() will remove *all* current mappers from all classes." % (self.class_, self.entity_name)) def extra_init(class_, oldinit, instance, args, kwargs): self.compile() if 'init_instance' in self.extension.methods: self.extension.init_instance(self, class_, oldinit, instance, args, kwargs) def on_exception(class_, oldinit, instance, args, kwargs): util.warn_exception(self.extension.init_failed, self, class_, oldinit, instance, args, kwargs) attributes.register_class(self.class_, extra_init=extra_init, on_exception=on_exception, deferred_scalar_loader=_load_scalar_attributes) self._class_state = self.class_._class_state _mapper_registry[self] = True self.class_._class_state.mappers[self.entity_name] = self for ext in util.to_list(self.extension, []): ext.instrument_class(self, self.class_) if self.entity_name is None: self.class_.c = self.c def common_parent(self, other): """Return true if the given mapper shares a common inherited parent as this mapper.""" return self.base_mapper is other.base_mapper def isa(self, other): """Return True if the given mapper inherits from this mapper.""" m = other while m is not self and m.inherits: m = m.inherits return m is self def iterate_to_root(self): m = self while m: yield m m = m.inherits def polymorphic_iterator(self): """Iterate 
        through the collection including this mapper and
        all descendant mappers.

        This includes not just the immediately inheriting mappers but
        all their inheriting mappers as well.

        To iterate through an entire hierarchy, use
        ``mapper.base_mapper.polymorphic_iterator()``.

        The order among sibling mappers follows iteration of the
        ``_inheriting_mappers`` set (presumably not guaranteed).
        """

        yield self
        for mapper in self._inheriting_mappers:
            for m in mapper.polymorphic_iterator():
                yield m

    def add_properties(self, dict_of_properties):
        """Add the given dictionary of properties to this mapper,
        using `add_property`.
        """

        for key, value in dict_of_properties.iteritems():
            self.add_property(key, value)

    def add_property(self, key, prop):
        """Add an individual MapperProperty to this mapper.

        The property is always recorded in the initial properties
        dictionary and compiled onto this mapper immediately.  Its
        ``init()`` step runs right away only if this mapper's properties
        have already been initialized (i.e. the mapper has been compiled).
        Otherwise ``init()`` is deferred to compile().
        """

        self._init_properties[key] = prop
        self._compile_property(key, prop, init=self.__props_init)

    def __str__(self):
        return "Mapper|" + self.class_.__name__ + "|" + (self.entity_name is not None and "/%s" % self.entity_name or "") + (self.local_table and self.local_table.description or str(self.local_table)) + (self.non_primary and "|non-primary" or "")

    def primary_mapper(self):
        """Return the primary mapper corresponding to this mapper's class key (class + entity_name)."""
        return self._class_state.mappers[self.entity_name]

    def get_session(self):
        """Return the contextual session provided by the mapper
        extension chain, if any.

        Raise ``InvalidRequestError`` if a session cannot be retrieved
        from the extension chain.
        """

        if 'get_session' in self.extension.methods:
            s = self.extension.get_session()
            # EXT_CONTINUE means no extension supplied a session
            if s is not EXT_CONTINUE:
                return s

        raise exceptions.InvalidRequestError("No contextual Session is established.")

    def instances(self, cursor, session, *mappers, **kwargs):
        """Return a list of mapped instances corresponding to the rows
        in a given ResultProxy.

        DEPRECATED.  Delegates to ``Query(self, session).instances()``.
        """

        # Importing the submodule binds the local name `sqlalchemy`; Query is
        # then taken from the sqlalchemy.orm package namespace.  The import
        # is function-level, presumably to avoid a circular import.
        import sqlalchemy.orm.query
        return sqlalchemy.orm.Query(self, session).instances(cursor, *mappers, **kwargs)
    instances = util.deprecated(None, False)(instances)

    def identity_key_from_row(self, row):
        """Return an identity-map key for use in storing/retrieving an
        item from the identity map.

        row
          A ``sqlalchemy.engine.base.RowProxy`` instance or a
          dictionary mapping result-set ``ColumnElement``
          instances to their values within a row.

        The key is the 3-tuple ``(identity class, tuple of primary key
        values, entity_name)``.
        """
        return (self._identity_class, tuple([row[column] for column in self.primary_key]), self.entity_name)

    def identity_key_from_primary_key(self, primary_key):
        """Return an identity-map key for use in storing/retrieving an
        item from an identity map.

        primary_key
          A list of values indicating the identifier (a scalar is
          accepted too; it is wrapped by util.to_list).
        """
        return (self._identity_class, tuple(util.to_list(primary_key)), self.entity_name)

    def identity_key_from_instance(self, instance):
        """Return the identity key for the given instance, based on
        its primary key attributes.

        This value is typically also found on the instance itself
        under the attribute name `_instance_key`.
        """
        return self.identity_key_from_primary_key(self.primary_key_from_instance(instance))

    def _identity_key_from_state(self, state):
        # state-based counterpart of identity_key_from_instance()
        return self.identity_key_from_primary_key(self._primary_key_from_state(state))

    def primary_key_from_instance(self, instance):
        """Return the list of primary key values for the given
        instance.
""" return [self._get_state_attr_by_column(instance._state, column) for column in self.primary_key] def _primary_key_from_state(self, state): return [self._get_state_attr_by_column(state, column) for column in self.primary_key] def _canload(self, state, allow_subtypes): if self.polymorphic_on or allow_subtypes: return self.isa(_state_mapper(state)) else: return state.class_ is self.class_ def _get_col_to_prop(self, column): try: return self._columntoproperty[column] except KeyError: prop = self.__props.get(column.key, None) if prop: raise exceptions.UnmappedColumnError("Column '%s.%s' is not available, due to conflicting property '%s':%s" % (column.table.name, column.name, column.key, repr(prop))) else: raise exceptions.UnmappedColumnError("No column %s.%s is configured on mapper %s..." % (column.table.name, column.name, str(self))) def _get_state_attr_by_column(self, state, column): return self._get_col_to_prop(column).getattr(state, column) def _set_state_attr_by_column(self, state, column, value): return self._get_col_to_prop(column).setattr(state, value, column) def _get_attr_by_column(self, obj, column): return self._get_col_to_prop(column).getattr(obj._state, column) def _get_committed_attr_by_column(self, obj, column): return self._get_col_to_prop(column).getcommitted(obj._state, column) def _set_attr_by_column(self, obj, column, value): self._get_col_to_prop(column).setattr(obj._state, column, value) def _save_obj(self, states, uowtransaction, postupdate=False, post_update_cols=None, single=False): """Issue ``INSERT`` and/or ``UPDATE`` statements for a list of objects. This is called within the context of a UOWTransaction during a flush operation. `_save_obj` issues SQL statements not just for instances mapped directly by this mapper, but for instances mapped by all inheriting mappers as well. This is to maintain proper insert ordering among a polymorphic chain of instances. 
Therefore _save_obj is typically called only on a *base mapper*, or a mapper which does not inherit from any other mapper. """ if self.__should_log_debug: self.__log_debug("_save_obj() start, " + (single and "non-batched" or "batched")) # if batch=false, call _save_obj separately for each object if not single and not self.batch: for state in states: self._save_obj([state], uowtransaction, postupdate=postupdate, post_update_cols=post_update_cols, single=True) return # if session has a connection callable, # organize individual states with the connection to use for insert/update if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] tups = [(state, connection_callable(self, state.obj()), _state_has_identity(state)) for state in states] else: connection = uowtransaction.transaction.connection(self) tups = [(state, connection, _state_has_identity(state)) for state in states] if not postupdate: # call before_XXX extensions for state, connection, has_identity in tups: mapper = _state_mapper(state) if not has_identity: if 'before_insert' in mapper.extension.methods: mapper.extension.before_insert(mapper, connection, state.obj()) else: if 'before_update' in mapper.extension.methods: mapper.extension.before_update(mapper, connection, state.obj()) for state, connection, has_identity in tups: # detect if we have a "pending" instance (i.e. has no instance_key attached to it), # and another instance with the same identity key already exists as persistent. convert to an # UPDATE if so. 
mapper = _state_mapper(state) instance_key = mapper._identity_key_from_state(state) if not postupdate and not has_identity and instance_key in uowtransaction.uow.identity_map: existing = uowtransaction.uow.identity_map[instance_key]._state if not uowtransaction.is_deleted(existing): raise exceptions.FlushError("New instance %s with identity key %s conflicts with persistent instance %s" % (state_str(state), str(instance_key), state_str(existing))) if self.__should_log_debug: self.__log_debug("detected row switch for identity %s. will update %s, remove %s from transaction" % (instance_key, state_str(state), state_str(existing))) uowtransaction.set_row_switch(existing) inserted_objects = util.Set() updated_objects = util.Set() table_to_mapper = {} for mapper in self.base_mapper.polymorphic_iterator(): for t in mapper.tables: table_to_mapper[t] = mapper for table in sqlutil.sort_tables(table_to_mapper.keys()): insert = [] update = [] for state, connection, has_identity in tups: mapper = _state_mapper(state) if table not in mapper._pks_by_table: continue pks = mapper._pks_by_table[table] instance_key = mapper._identity_key_from_state(state) if self.__should_log_debug: self.__log_debug("_save_obj() table '%s' instance %s identity %s" % (table.name, state_str(state), str(instance_key))) isinsert = not instance_key in uowtransaction.uow.identity_map and not postupdate and not has_identity params = {} value_params = {} hasdata = False if isinsert: for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: params[col.key] = 1 elif col in pks: value = mapper._get_state_attr_by_column(state, col) if value is not None: params[col.key] = value elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col): if self.__should_log_debug: self.__log_debug("Using polymorphic identity '%s' for insert column '%s'" % (mapper.polymorphic_identity, col.key)) value = mapper.polymorphic_identity if col.default is None or value is not None: params[col.key] = value 
else: value = mapper._get_state_attr_by_column(state, col) if col.default is None or value is not None: if isinstance(value, sql.ClauseElement): value_params[col] = value else: params[col.key] = value insert.append((state, params, mapper, connection, value_params)) else: for col in mapper._cols_by_table[table]: if col is mapper.version_id_col: params[col._label] = mapper._get_state_attr_by_column(state, col) params[col.key] = params[col._label] + 1 for prop in mapper._columntoproperty.values(): (added, unchanged, deleted) = attributes.get_history(state, prop.key, passive=True) if added: hasdata = True elif mapper.polymorphic_on and mapper.polymorphic_on.shares_lineage(col): pass else: if post_update_cols is not None and col not in post_update_cols: if col in pks: params[col._label] = mapper._get_state_attr_by_column(state, col) continue prop = mapper._columntoproperty[col] (added, unchanged, deleted) = attributes.get_history(state, prop.key, passive=True) if added: if isinstance(added[0], sql.ClauseElement): value_params[col] = added[0] else: params[col.key] = prop.get_col_value(col, added[0]) if col in pks: if deleted: params[col._label] = prop.get_col_value(col, deleted[0]) else: # row switch logic can reach us here params[col._label] = prop.get_col_value(col, added[0]) hasdata = True elif col in pks: params[col._label] = mapper._get_state_attr_by_column(state, col) if hasdata: update.append((state, params, mapper, connection, value_params)) if update: mapper = table_to_mapper[table] clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col._label, type_=col.type)) if mapper.version_id_col and table.c.contains_column(mapper.version_id_col): clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type)) statement = table.update(clause) pks = mapper._pks_by_table[table] def comparator(a, b): for col in pks: x = cmp(a[1][col._label],b[1][col._label]) if x != 0: return x 
return 0 update.sort(comparator) rows = 0 for rec in update: (state, params, mapper, connection, value_params) = rec c = connection.execute(statement.values(value_params), params) mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_updated_params(), value_params) # testlib.pragma exempt:__hash__ updated_objects.add((state, connection)) rows += c.rowcount if c.supports_sane_rowcount() and rows != len(update): raise exceptions.ConcurrentModificationError("Updated rowcount %d does not match number of objects updated %d" % (rows, len(update))) if insert: statement = table.insert() def comparator(a, b): return cmp(a[0].insert_order, b[0].insert_order) insert.sort(comparator) for rec in insert: (state, params, mapper, connection, value_params) = rec c = connection.execute(statement.values(value_params), params) primary_key = c.last_inserted_ids() if primary_key is not None: # set primary key attributes for i, col in enumerate(mapper._pks_by_table[table]): if mapper._get_state_attr_by_column(state, col) is None and len(primary_key) > i: mapper._set_state_attr_by_column(state, col, primary_key[i]) mapper.__postfetch(uowtransaction, connection, table, state, c, c.last_inserted_params(), value_params) # synchronize newly inserted ids from one table to the next # TODO: this fires off more than needed, try to organize syncrules # per table for m in util.reversed(list(mapper.iterate_to_root())): if m.__inherits_equated_pairs: m.__synchronize_inherited(state) # testlib.pragma exempt:__hash__ inserted_objects.add((state, connection)) if not postupdate: # call after_XXX extensions for state, connection, has_identity in tups: mapper = _state_mapper(state) if not has_identity: if 'after_insert' in mapper.extension.methods: mapper.extension.after_insert(mapper, connection, state.obj()) else: if 'after_update' in mapper.extension.methods: mapper.extension.after_update(mapper, connection, state.obj()) def __synchronize_inherited(self, state): sync.populate(state, 
self, state, self, self.__inherits_equated_pairs) def __postfetch(self, uowtransaction, connection, table, state, resultproxy, params, value_params): """After an ``INSERT`` or ``UPDATE``, assemble newly generated values on an instance. For columns which are marked as being generated on the database side, set up a group-based "deferred" loader which will populate those attributes in one query when next accessed. """ postfetch_cols = resultproxy.postfetch_cols() generated_cols = list(resultproxy.prefetch_cols()) if self.polymorphic_on: po = table.corresponding_column(self.polymorphic_on) if po: generated_cols.append(po) if self.version_id_col: generated_cols.append(self.version_id_col) for c in generated_cols: if c.key in params: self._set_state_attr_by_column(state, c, params[c.key]) deferred_props = [prop.key for prop in [self._columntoproperty[c] for c in postfetch_cols]] if deferred_props: if self.eager_defaults: _instance_key = self._identity_key_from_state(state) state.dict['_instance_key'] = _instance_key uowtransaction.session.query(self)._get(_instance_key, refresh_instance=state, only_load_props=deferred_props) else: _expire_state(state, deferred_props) def _delete_obj(self, states, uowtransaction): """Issue ``DELETE`` statements for a list of objects. This is called within the context of a UOWTransaction during a flush operation. 
""" if self.__should_log_debug: self.__log_debug("_delete_obj() start") if 'connection_callable' in uowtransaction.mapper_flush_opts: connection_callable = uowtransaction.mapper_flush_opts['connection_callable'] tups = [(state, connection_callable(self, state.obj())) for state in states] else: connection = uowtransaction.transaction.connection(self) tups = [(state, connection) for state in states] for (state, connection) in tups: mapper = _state_mapper(state) if 'before_delete' in mapper.extension.methods: mapper.extension.before_delete(mapper, connection, state.obj()) deleted_objects = util.Set() table_to_mapper = {} for mapper in self.base_mapper.polymorphic_iterator(): for t in mapper.tables: table_to_mapper[t] = mapper for table in sqlutil.sort_tables(table_to_mapper.keys(), reverse=True): delete = {} for (state, connection) in tups: mapper = _state_mapper(state) if table not in mapper._pks_by_table: continue params = {} if not _state_has_identity(state): continue else: delete.setdefault(connection, []).append(params) for col in mapper._pks_by_table[table]: params[col.key] = mapper._get_state_attr_by_column(state, col) if mapper.version_id_col and table.c.contains_column(mapper.version_id_col): params[mapper.version_id_col.key] = mapper._get_state_attr_by_column(state, mapper.version_id_col) # testlib.pragma exempt:__hash__ deleted_objects.add((state, connection)) for connection, del_objects in delete.iteritems(): mapper = table_to_mapper[table] def comparator(a, b): for col in mapper._pks_by_table[table]: x = cmp(a[col.key],b[col.key]) if x != 0: return x return 0 del_objects.sort(comparator) clause = sql.and_() for col in mapper._pks_by_table[table]: clause.clauses.append(col == sql.bindparam(col.key, type_=col.type)) if mapper.version_id_col and table.c.contains_column(mapper.version_id_col): clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type)) statement = table.delete(clause) c = 
connection.execute(statement, del_objects) if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): raise exceptions.ConcurrentModificationError("Deleted rowcount %d does not match number of objects deleted %d" % (c.rowcount, len(del_objects))) for state, connection in deleted_objects: mapper = _state_mapper(state) if 'after_delete' in mapper.extension.methods: mapper.extension.after_delete(mapper, connection, state.obj()) def _register_dependencies(self, uowcommit): """Register ``DependencyProcessor`` instances with a ``unitofwork.UOWTransaction``. This call `register_dependencies` on all attached ``MapperProperty`` instances. """ for prop in self.__props.values(): prop.register_dependencies(uowcommit) for dep in self._dependency_processors: dep.register_dependencies(uowcommit) def cascade_iterator(self, type_, state, halt_on=None): """Iterate each element and its mapper in an object graph, for all relations that meet the given cascade rule. type\_ The name of the cascade rule (i.e. save-update, delete, etc.) state The lead InstanceState. child items will be processed per the relations defined for this object's mapper. the return value are object instances; this provides a strong reference so that they don't fall out of scope immediately. 
""" visited_instances = util.IdentitySet() visitables = [(self.__props.itervalues(), 'property', state)] while visitables: iterator,item_type,parent_state = visitables[-1] try: if item_type == 'property': prop = iterator.next() visitables.append((prop.cascade_iterator(type_, parent_state, visited_instances, halt_on), 'mapper', None)) elif item_type == 'mapper': instance, instance_mapper, corresponding_state = iterator.next() yield (instance, instance_mapper) visitables.append((instance_mapper.__props.itervalues(), 'property', corresponding_state)) except StopIteration: visitables.pop() def _instance(self, context, row, result=None, polymorphic_from=None, extension=None, only_load_props=None, refresh_instance=None): if not extension: extension = self.extension if 'translate_row' in extension.methods: ret = extension.translate_row(self, context, row) if ret is not EXT_CONTINUE: row = ret if polymorphic_from: # if we are called from a base mapper doing a polymorphic load, figure out what tables, # if any, will need to be "post-fetched" based on the tables present in the row, # or from the options set up on the query if ('polymorphic_fetch', self) not in context.attributes: if self in context.query._with_polymorphic: context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, []) else: context.attributes[('polymorphic_fetch', self)] = (polymorphic_from, [t for t in self.tables if t not in polymorphic_from.tables]) elif not refresh_instance and self.polymorphic_on: discriminator = row[self.polymorphic_on] if discriminator is not None: try: mapper = self.polymorphic_map[discriminator] except KeyError: raise exceptions.AssertionError("No such polymorphic_identity %r is defined" % discriminator) if mapper is not self: return mapper._instance(context, row, result=result, polymorphic_from=self) # determine identity key if refresh_instance: try: identitykey = refresh_instance.dict['_instance_key'] except KeyError: # super-rare condition; a refresh is being called # 
on a non-instance-key instance; this is meant to only # occur wihtin a flush() identitykey = self._identity_key_from_state(refresh_instance) else: identitykey = self.identity_key_from_row(row) session_identity_map = context.session.identity_map if identitykey in session_identity_map: instance = session_identity_map[identitykey] state = instance._state if self.__should_log_debug: self.__log_debug("_instance(): using existing instance %s identity %s" % (instance_str(instance), str(identitykey))) isnew = state.runid != context.runid currentload = not isnew if not currentload and context.version_check and self.version_id_col and self._get_attr_by_column(instance, self.version_id_col) != row[self.version_id_col]: raise exceptions.ConcurrentModificationError("Instance '%s' version of %s does not match %s" % (instance, self._get_attr_by_column(instance, self.version_id_col), row[self.version_id_col])) elif refresh_instance: # out of band refresh_instance detected (i.e. its not in the session.identity_map) # honor it anyway. this can happen if a _get() occurs within save_obj(), such as # when eager_defaults is True. 
state = refresh_instance instance = state.obj() isnew = state.runid != context.runid currentload = True else: if self.__should_log_debug: self.__log_debug("_instance(): identity key %s not in session" % str(identitykey)) if self.allow_null_pks: for x in identitykey[1]: if x is not None: break else: return None else: if None in identitykey[1]: return None isnew = True currentload = True if 'create_instance' in extension.methods: instance = extension.create_instance(self, context, row, self.class_) if instance is EXT_CONTINUE: instance = attributes.new_instance(self.class_) else: attributes.manage(instance) else: instance = attributes.new_instance(self.class_) if self.__should_log_debug: self.__log_debug("_instance(): created new instance %s identity %s" % (instance_str(instance), str(identitykey))) state = instance._state instance._entity_name = self.entity_name instance._instance_key = identitykey instance._sa_session_id = context.session.hash_key session_identity_map[identitykey] = instance if currentload or context.populate_existing or self.always_refresh: if isnew: state.runid = context.runid context.progress.add(state) if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: self.populate_instance(context, instance, row, only_load_props=only_load_props, instancekey=identitykey, isnew=isnew) else: # populate attributes on non-loading instances which have been expired # TODO: also support deferred attributes here [ticket:870] if state.expired_attributes: if state in context.partials: isnew = False attrs = context.partials[state] else: isnew = True attrs = state.expired_attributes.intersection(state.unmodified) context.partials[state] = attrs #<-- allow query.instances to commit the subset of attrs if 'populate_instance' not in extension.methods or extension.populate_instance(self, context, row, instance, 
only_load_props=attrs, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE: self.populate_instance(context, instance, row, only_load_props=attrs, instancekey=identitykey, isnew=isnew) if result is not None and ('append_result' not in extension.methods or extension.append_result(self, context, row, instance, result, instancekey=identitykey, isnew=isnew) is EXT_CONTINUE): result.append(instance) return instance def populate_instance(self, selectcontext, instance, row, ispostselect=None, isnew=False, only_load_props=None, **flags): """populate an instance from a result row.""" snapshot = selectcontext.path + (self,) # retrieve a set of "row population" functions derived from the MapperProperties attached # to this Mapper. These are keyed in the select context based primarily off the # "snapshot" of the stack, which represents a path from the lead mapper in the query to this one, # including relation() names. the key also includes "self", and allows us to distinguish between # other mappers within our inheritance hierarchy (new_populators, existing_populators) = selectcontext.attributes.get(('populators', self, snapshot, ispostselect), (None, None)) if new_populators is None: # no populators; therefore this is the first time we are receiving a row for # this result set. issue create_row_processor() on all MapperProperty objects # and cache in the select context. 
new_populators = [] existing_populators = [] post_processors = [] for prop in self.__props.values(): (newpop, existingpop, post_proc) = selectcontext.exec_with_path(self, prop.key, prop.create_row_processor, selectcontext, self, row) if newpop: new_populators.append((prop.key, newpop)) if existingpop: existing_populators.append((prop.key, existingpop)) if post_proc: post_processors.append(post_proc) # install a post processor for immediate post-load of joined-table inheriting mappers poly_select_loader = self._get_poly_select_loader(selectcontext, row) if poly_select_loader: post_processors.append(poly_select_loader) selectcontext.attributes[('populators', self, snapshot, ispostselect)] = (new_populators, existing_populators) selectcontext.attributes[('post_processors', self, ispostselect)] = post_processors if isnew or ispostselect: populators = new_populators else: populators = existing_populators if only_load_props: populators = [p for p in populators if p[0] in only_load_props] for (key, populator) in populators: selectcontext.exec_with_path(self, key, populator, instance, row, ispostselect=ispostselect, isnew=isnew, **flags) if self.non_primary: selectcontext.attributes[('populating_mapper', instance._state)] = self def _post_instance(self, selectcontext, state, **kwargs): post_processors = selectcontext.attributes[('post_processors', self, None)] for p in post_processors: p(state.obj(), **kwargs) def _get_poly_select_loader(self, selectcontext, row): """set up attribute loaders for 'select' and 'deferred' polymorphic loading. this loading uses a second SELECT statement to load additional tables, either immediately after loading the main table or via a deferred attribute trigger. 
""" (hosted_mapper, needs_tables) = selectcontext.attributes.get(('polymorphic_fetch', self), (None, None)) if hosted_mapper is None or not needs_tables: return cond, param_names = self._deferred_inheritance_condition(hosted_mapper, needs_tables) statement = sql.select(needs_tables, cond, use_labels=True) if hosted_mapper.polymorphic_fetch == 'select': def post_execute(instance, **flags): if self.__should_log_debug: self.__log_debug("Post query loading instance " + instance_str(instance)) identitykey = self.identity_key_from_instance(instance) only_load_props = flags.get('only_load_props', None) params = {} for c, bind in param_names: params[bind] = self._get_attr_by_column(instance, c) row = selectcontext.session.connection(self).execute(statement, params).fetchone() self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True, only_load_props=only_load_props) return post_execute elif hosted_mapper.polymorphic_fetch == 'deferred': from sqlalchemy.orm.strategies import DeferredColumnLoader def post_execute(instance, **flags): def create_statement(instance): params = {} for (c, bind) in param_names: # use the "committed" (database) version to get query column values params[bind] = self._get_committed_attr_by_column(instance, c) return (statement, params) props = [prop for prop in [self._get_col_to_prop(col) for col in statement.inner_columns] if prop.key not in instance.__dict__] keys = [p.key for p in props] only_load_props = flags.get('only_load_props', None) if only_load_props: keys = util.Set(keys).difference(only_load_props) props = [p for p in props if p.key in only_load_props] for prop in props: strategy = prop._get_strategy(DeferredColumnLoader) instance._state.set_callable(prop.key, strategy.setup_loader(instance, props=keys, create_statement=create_statement)) return post_execute else: return None def _deferred_inheritance_condition(self, base_mapper, needs_tables): base_mapper = base_mapper.primary_mapper() 
def visit_binary(binary): leftcol = binary.left rightcol = binary.right if leftcol is None or rightcol is None: return if leftcol.table not in needs_tables: binary.left = sql.bindparam(None, None, type_=binary.right.type) param_names.append((leftcol, binary.left)) elif rightcol.table not in needs_tables: binary.right = sql.bindparam(None, None, type_=binary.right.type) param_names.append((rightcol, binary.right)) allconds = [] param_names = [] for mapper in self.iterate_to_root(): if mapper is base_mapper: break if not mapper.single: allconds.append(visitors.traverse(mapper.inherit_condition, clone=True, visit_binary=visit_binary)) return sql.and_(*allconds), param_names Mapper.logger = logging.class_logger(Mapper) object_session = None def _load_scalar_attributes(instance, attribute_names): mapper = object_mapper(instance) global object_session if not object_session: from sqlalchemy.orm.session import object_session session = object_session(instance) if not session: try: session = mapper.get_session() except exceptions.InvalidRequestError: raise exceptions.UnboundExecutionError("Instance %s is not bound to a Session, and no contextual session is established; attribute refresh operation cannot proceed" % (instance.__class__)) state = instance._state if '_instance_key' in state.dict: identity_key = state.dict['_instance_key'] shouldraise = True else: # if instance is pending, a refresh operation may not complete (even if PK attributes are assigned) shouldraise = False identity_key = mapper._identity_key_from_state(state) if session.query(mapper)._get(identity_key, refresh_instance=state, only_load_props=attribute_names) is None and shouldraise: raise exceptions.InvalidRequestError("Could not refresh instance '%s'" % instance_str(instance))