diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-09-10 23:52:04 +0000 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2006-09-10 23:52:04 +0000 |
| commit | 47d8b03b14145997fc0936bd674363f0e213f019 (patch) | |
| tree | 64e9356d7c5f435b68d7deaa8b41c6a4cdc5c521 /lib/sqlalchemy | |
| parent | 287bf217958fbccb07cffafcc4481a2b6c7f2784 (diff) | |
| download | sqlalchemy-47d8b03b14145997fc0936bd674363f0e213f019.tar.gz | |
- changed "for_update" parameter to accept False/True/"nowait"
and "read", the latter two of which are interpreted only by
Oracle and Mysql [ticket:292]
- added "lockmode" argument to base Query select/get functions,
including "with_lockmode" function to get a Query copy that has
a default locking mode. Will translate "read"/"update"
arguments into a for_update argument on the select side.
[ticket:292]
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/ansisql.py | 14 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/mysql.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/databases/oracle.py | 6 | ||||
| -rw-r--r-- | lib/sqlalchemy/orm/query.py | 28 |
4 files changed, 39 insertions, 15 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index c44595f36..66b917c20 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -394,13 +394,9 @@ class ANSICompiler(sql.Compiled): text += " ORDER BY " + order_by text += self.visit_select_postclauses(select) - - if select.for_update: - text += " FOR UPDATE" - if select.nowait: - text += " NOWAIT" - + text += self.for_update_clause(select) + if getattr(select, 'parens', False): self.strings[select] = "(" + text + ")" else: @@ -415,6 +411,12 @@ class ANSICompiler(sql.Compiled): """ called when building a SELECT statement, position is after all other SELECT clauses. Most DB syntaxes put LIMIT/OFFSET here """ return (select.limit or select.offset) and self.limit_clause(select) or "" + def for_update_clause(self, select): + if select.for_update: + return " FOR UPDATE" + else: + return "" + def limit_clause(self, select): text = "" if select.limit is not None: diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py index c6d78cf90..4eab9e55c 100644 --- a/lib/sqlalchemy/databases/mysql.py +++ b/lib/sqlalchemy/databases/mysql.py @@ -411,6 +411,12 @@ class MySQLCompiler(ansisql.ANSICompiler): # TODO: put whatever MySQL does for CAST here. self.strings[cast] = self.strings[cast.clause] + def for_update_clause(self, select): + if select.for_update == 'read': + return ' LOCK IN SHARE MODE' + else: + return super(MySQLCompiler, self).for_update_clause(select) + def limit_clause(self, select): text = "" if select.limit is not None: diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py index 5f574338b..5db157cbb 100644 --- a/lib/sqlalchemy/databases/oracle.py +++ b/lib/sqlalchemy/databases/oracle.py @@ -402,6 +402,12 @@ class OracleCompiler(ansisql.ANSICompiler): def limit_clause(self, select): return "" + def for_update_clause(self, select): + if select.for_update=="nowait": + return " FOR UPDATE NOWAIT" + else: + return super(OracleCompiler, self).for_update_clause(select) + class OracleSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py index 29cc56761..d35219208 100644 --- a/lib/sqlalchemy/orm/query.py +++ b/lib/sqlalchemy/orm/query.py @@ -1,4 +1,4 @@ -# orm/query.py + # orm/query.py # Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com # # This module is part of SQLAlchemy and is released under @@ -12,7 +12,7 @@ import mapper class Query(object): """encapsulates the object-fetching operations provided by Mappers.""" - def __init__(self, class_or_mapper, session=None, entity_name=None, **kwargs): + def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, **kwargs): if isinstance(class_or_mapper, type): self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name) else: @@ -20,6 +20,7 @@ class Query(object): self.mapper = self.mapper.get_select_mapper().compile() self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh) self.order_by = kwargs.pop('order_by', self.mapper.order_by) + self.lockmode = lockmode self.extension = kwargs.pop('extension', self.mapper.extension) self._session = session if not hasattr(self.mapper, '_get_clause'): @@ -67,7 +68,8 @@ class Query(object): e.g. u = usermapper.get_by(user_name = 'fred') """ - x = self.select_whereclause(self.join_by(*args, **params), limit=1) + lockmode=params.pop('lockmode', self.lockmode) + x = self.select_whereclause(self.join_by(*args, **params), lockmode=lockmode, limit=1) if x: return x[0] else: @@ -248,7 +250,11 @@ class Query(object): def options(self, *args, **kwargs): """returns a new Query object using the given MapperOptions.""" return self.mapper.options(*args, **kwargs).using(session=self._session) - + + def with_lockmode(self, mode): + """return a new Query object with the specified locking mode.""" + return Query(self.mapper, self._session, lockmode=mode) + def __getattr__(self, key): if (key.startswith('select_by_')): key = key[10:] @@ -270,8 +276,9 @@ class Query(object): finally: result.close() - def _get(self, key, ident=None, reload=False): - if not reload and not self.always_refresh: + def _get(self, key, ident=None, reload=False, lockmode=None): + lockmode = lockmode or self.lockmode + if not reload and not self.always_refresh and lockmode == None: try: return self.session._get(key) except KeyError: @@ -293,7 +300,7 @@ class Query(object): if len(ident) > i + 1: i += 1 try: - statement = self.compile(self._get_clause) + statement = self.compile(self._get_clause, lockmode=lockmode) return self._select_statement(statement, params=params, populate_existing=reload)[0] except IndexError: return None @@ -320,11 +327,14 @@ class Query(object): def compile(self, whereclause = None, **kwargs): order_by = kwargs.pop('order_by', False) from_obj = kwargs.pop('from_obj', []) + lockmode = kwargs.pop('lockmode', self.lockmode) if order_by is False: order_by = self.order_by if order_by is False: if self.table.default_order_by() is not None: order_by = self.table.default_order_by() + + for_update = {'read':'read','update':True,'update_nowait':'nowait'}.get(lockmode, False) if self.mapper.single and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None: whereclause = sql.and_(whereclause, self.mapper.polymorphic_on==self.mapper.polymorphic_identity) @@ -349,7 +359,7 @@ class Query(object): crit = [] for i in range(0, len(self.table.primary_key)): crit.append(s3.primary_key[i] == self.table.primary_key[i]) - statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True) + statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True, for_update=for_update) # raise "OK statement", str(statement) # now for the order by, convert the columns to their corresponding columns @@ -364,7 +374,7 @@ class Query(object): statement.order_by(*util.to_list(order_by)) else: from_obj.append(self.table) - statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, **kwargs) + statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **kwargs) if order_by: statement.order_by(*util.to_list(order_by)) # for a DISTINCT query, you need the columns explicitly specified in order |
