diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/orm/mapper.py | 20 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/strategies.py | 19 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 29 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/expression.py | 23 |
4 files changed, 46 insertions, 45 deletions
diff --git a/lib/sqlalchemy/orm/mapper.py b/lib/sqlalchemy/orm/mapper.py index 18c639af0..5d727bd6a 100644 --- a/lib/sqlalchemy/orm/mapper.py +++ b/lib/sqlalchemy/orm/mapper.py @@ -1066,9 +1066,9 @@ class Mapper(object): mapper = table_to_mapper[table] clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col._label, type_=col.type, unique=True)) + clause.clauses.append(col == sql.bindparam(col._label, type_=col.type)) if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): - clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type, unique=True)) + clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col._label, type_=col.type)) statement = table.update(clause) rows = 0 supports_sane_rowcount = True @@ -1210,9 +1210,9 @@ class Mapper(object): del_objects.sort(comparator) clause = sql.and_() for col in mapper._pks_by_table[table]: - clause.clauses.append(col == sql.bindparam(col.key, type_=col.type, unique=True)) + clause.clauses.append(col == sql.bindparam(col.key, type_=col.type)) if mapper.version_id_col is not None and table.c.contains_column(mapper.version_id_col): - clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type, unique=True)) + clause.clauses.append(mapper.version_id_col == sql.bindparam(mapper.version_id_col.key, type_=mapper.version_id_col.type)) statement = table.delete(clause) c = connection.execute(statement, del_objects) if c.supports_sane_multi_rowcount() and c.rowcount != len(del_objects): @@ -1389,11 +1389,11 @@ class Mapper(object): if leftcol is None or rightcol is None: return if leftcol.table not in needs_tables: - binary.left = sql.bindparam(leftcol.name, None, type_=binary.right.type, unique=True) - param_names.append(leftcol) + binary.left = sql.bindparam(None, None, type_=binary.right.type) + param_names.append((leftcol, binary.left)) elif rightcol not in needs_tables: - binary.right = sql.bindparam(rightcol.name, None, type_=binary.right.type, unique=True) - param_names.append(rightcol) + binary.right = sql.bindparam(None, None, type_=binary.right.type) + param_names.append((rightcol, binary.right)) allconds = [] param_names = [] @@ -1487,8 +1487,8 @@ class Mapper(object): identitykey = self.identity_key_from_instance(instance) params = {} - for c in param_names: - params[c.name] = self._get_attr_by_column(instance, c) + for c, bind in param_names: + params[bind] = self._get_attr_by_column(instance, c) row = selectcontext.session.connection(self).execute(statement, params).fetchone() self.populate_instance(selectcontext, instance, row, isnew=False, instancekey=identitykey, ispostselect=True) diff --git a/lib/sqlalchemy/orm/strategies.py b/lib/sqlalchemy/orm/strategies.py index 3c647ac60..5a765fbd3 100644 --- a/lib/sqlalchemy/orm/strategies.py +++ b/lib/sqlalchemy/orm/strategies.py @@ -108,8 +108,8 @@ class ColumnLoader(LoaderStrategy): statement = sql.select(needs_tables, cond, use_labels=True) def create_statement(instance): params = {} - for c in param_names: - params[c.name] = mapper._get_attr_by_column(instance, c) + for (c, bind) in param_names: + params[bind] = mapper._get_attr_by_column(instance, c) return (statement, params) def new_execute(instance, row, isnew, **flags): @@ -297,12 +297,11 @@ class LazyLoader(AbstractRelationLoader): (criterion, lazybinds, rev) = LazyLoader._create_lazy_clause(self.parent_property, reverse_direction=reverse_direction) bind_to_col = dict([(lazybinds[col].key, col) for col in lazybinds]) - class Visitor(visitors.ClauseVisitor): - def visit_bindparam(s, bindparam): - mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent - if bindparam.key in bind_to_col: - bindparam.value = mapper._get_attr_by_column(instance, bind_to_col[bindparam.key]) - return Visitor().traverse(criterion, clone=True) + def visit_bindparam(bindparam): + mapper = reverse_direction and self.parent_property.mapper or self.parent_property.parent + if bindparam.key in bind_to_col: + bindparam.value = mapper._get_attr_by_column(instance, bind_to_col[bindparam.key]) + return visitors.traverse(criterion, clone=True, visit_bindparam=visit_bindparam) def setup_loader(self, instance, options=None, path=None): if not mapper.has_mapper(instance): @@ -416,7 +415,7 @@ class LazyLoader(AbstractRelationLoader): if should_bind(leftcol, rightcol): col = leftcol binary.left = binds.setdefault(leftcol, - sql.bindparam(None, None, type_=binary.right.type, unique=True)) + sql.bindparam(None, None, type_=binary.right.type)) reverse[rightcol] = binds[col] # the "left is not right" compare is to handle part of a join clause that is "table.c.col1==table.c.col1", @@ -424,7 +423,7 @@ class LazyLoader(AbstractRelationLoader): if leftcol is not rightcol and should_bind(rightcol, leftcol): col = rightcol binary.right = binds.setdefault(rightcol, - sql.bindparam(None, None, type_=binary.left.type, unique=True)) + sql.bindparam(None, None, type_=binary.left.type)) reverse[leftcol] = binds[col] lazywhere = primaryjoin diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 749ce4c10..aec75e76c 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -196,7 +196,7 @@ class DefaultCompiler(engine.Compiled): if params: pd = {} for bindparam, name in self.bind_names.iteritems(): - for paramname in (bindparam.key, bindparam.shortname, name): + for paramname in (bindparam, bindparam.key, bindparam.shortname, name): if paramname in params: pd[name] = params[paramname] break @@ -373,26 +373,13 @@ class DefaultCompiler(engine.Compiled): return self.operators.get(operator, str(operator)) def visit_bindparam(self, bindparam, **kwargs): - # TODO: remove this whole "unique" thing, just use regular - # anonymous params to implement. params used for inserts/updates - # etc. should no longer be "unique". - if bindparam.unique: - count = 1 - key = bindparam.key - # redefine the generated name of the bind param in the case - # that we have multiple conflicting bind parameters. - while self.binds.setdefault(key, bindparam) is not bindparam: - tag = "_%d" % count - key = bindparam.key + tag - count += 1 - bindparam.key = key - return self.bindparam_string(self._truncate_bindparam(bindparam)) - else: - existing = self.binds.get(bindparam.key) - if existing is not None and existing.unique: + name = self._truncate_bindparam(bindparam) + if name in self.binds: + existing = self.binds[name] + if existing.unique or bindparam.unique: raise exceptions.CompileError("Bind parameter '%s' conflicts with unique bind parameter of the same name" % bindparam.key) - self.binds[bindparam.key] = bindparam - return self.bindparam_string(self._truncate_bindparam(bindparam)) + self.binds[bindparam.key] = self.binds[name] = bindparam + return self.bindparam_string(name) def _truncate_bindparam(self, bindparam): if bindparam in self.bind_names: @@ -632,7 +619,7 @@ class DefaultCompiler(engine.Compiled): """ def create_bind_param(col, value): - bindparam = sql.bindparam(col.key, value, type_=col.type, unique=True) + bindparam = sql.bindparam(col.key, value, type_=col.type) self.binds[col.key] = bindparam return self.bindparam_string(self._truncate_bindparam(bindparam)) diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py index 6ad29218f..732d4fdf9 100644 --- a/lib/sqlalchemy/sql/expression.py +++ b/lib/sqlalchemy/sql/expression.py @@ -922,7 +922,7 @@ class ClauseElement(object): if bind.key in kwargs: bind.value = kwargs[bind.key] if unique: - bind.unique=True + bind._convert_to_unique() return Vis().traverse(self, clone=True) def compare(self, other): @@ -1749,10 +1749,14 @@ class _BindParamClause(ClauseElement, _CompareMixin): if True, the parameter should be treated like a stored procedure "OUT" parameter. """ - - self.key = key or "{ANON %d param}" % id(self) - self.value = value + + if unique: + self.key = "{ANON %d %s}" % (id(self), key or 'param') + else: + self.key = key or "{ANON %d param}" % id(self) + self._orig_key = key self.unique = unique + self.value = value self.isoutparam = isoutparam self.shortname = shortname @@ -1778,6 +1782,17 @@ class _BindParamClause(ClauseElement, _CompareMixin): type(None):sqltypes.NullType } + def _clone(self): + c = ClauseElement._clone(self) + if self.unique: + c.key = "{ANON %d %s}" % (id(c), c._orig_key or 'param') + return c + + def _convert_to_unique(self): + if not self.unique: + self.unique=True + self.key = "{ANON %d %s}" % (id(self), self._orig_key or 'param') + def _get_from_objects(self, **modifiers): return [] |
