summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2006-09-10 23:52:04 +0000
committerMike Bayer <mike_mp@zzzcomputing.com>2006-09-10 23:52:04 +0000
commit47d8b03b14145997fc0936bd674363f0e213f019 (patch)
tree64e9356d7c5f435b68d7deaa8b41c6a4cdc5c521 /lib/sqlalchemy
parent287bf217958fbccb07cffafcc4481a2b6c7f2784 (diff)
downloadsqlalchemy-47d8b03b14145997fc0936bd674363f0e213f019.tar.gz
- changed "for_update" parameter to accept False/True/"nowait"
and "read", the latter two of which are interpreted only by Oracle and Mysql [ticket:292] - added "lockmode" argument to base Query select/get functions, including "with_lockmode" function to get a Query copy that has a default locking mode. Will translate "read"/"update" arguments into a for_update argument on the select side. [ticket:292]
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/ansisql.py14
-rw-r--r--lib/sqlalchemy/databases/mysql.py6
-rw-r--r--lib/sqlalchemy/databases/oracle.py6
-rw-r--r--lib/sqlalchemy/orm/query.py28
4 files changed, 39 insertions, 15 deletions
diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index c44595f36..66b917c20 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -394,13 +394,9 @@ class ANSICompiler(sql.Compiled):
text += " ORDER BY " + order_by
text += self.visit_select_postclauses(select)
-
- if select.for_update:
- text += " FOR UPDATE"
- if select.nowait:
- text += " NOWAIT"
-
+ text += self.for_update_clause(select)
+
if getattr(select, 'parens', False):
self.strings[select] = "(" + text + ")"
else:
@@ -415,6 +411,12 @@ class ANSICompiler(sql.Compiled):
""" called when building a SELECT statement, position is after all other SELECT clauses. Most DB syntaxes put LIMIT/OFFSET here """
return (select.limit or select.offset) and self.limit_clause(select) or ""
+ def for_update_clause(self, select):
+ if select.for_update:
+ return " FOR UPDATE"
+ else:
+ return ""
+
def limit_clause(self, select):
text = ""
if select.limit is not None:
diff --git a/lib/sqlalchemy/databases/mysql.py b/lib/sqlalchemy/databases/mysql.py
index c6d78cf90..4eab9e55c 100644
--- a/lib/sqlalchemy/databases/mysql.py
+++ b/lib/sqlalchemy/databases/mysql.py
@@ -411,6 +411,12 @@ class MySQLCompiler(ansisql.ANSICompiler):
# TODO: put whatever MySQL does for CAST here.
self.strings[cast] = self.strings[cast.clause]
+ def for_update_clause(self, select):
+ if select.for_update == 'read':
+ return ' LOCK IN SHARE MODE'
+ else:
+ return super(MySQLCompiler, self).for_update_clause(select)
+
def limit_clause(self, select):
text = ""
if select.limit is not None:
diff --git a/lib/sqlalchemy/databases/oracle.py b/lib/sqlalchemy/databases/oracle.py
index 5f574338b..5db157cbb 100644
--- a/lib/sqlalchemy/databases/oracle.py
+++ b/lib/sqlalchemy/databases/oracle.py
@@ -402,6 +402,12 @@ class OracleCompiler(ansisql.ANSICompiler):
def limit_clause(self, select):
return ""
+ def for_update_clause(self, select):
+ if select.for_update=="nowait":
+ return " FOR UPDATE NOWAIT"
+ else:
+ return super(OracleCompiler, self).for_update_clause(select)
+
class OracleSchemaGenerator(ansisql.ANSISchemaGenerator):
def get_column_specification(self, column, **kwargs):
colspec = self.preparer.format_column(column)
diff --git a/lib/sqlalchemy/orm/query.py b/lib/sqlalchemy/orm/query.py
index 29cc56761..d35219208 100644
--- a/lib/sqlalchemy/orm/query.py
+++ b/lib/sqlalchemy/orm/query.py
@@ -1,4 +1,4 @@
-# orm/query.py
+ # orm/query.py
# Copyright (C) 2005,2006 Michael Bayer mike_mp@zzzcomputing.com
#
# This module is part of SQLAlchemy and is released under
@@ -12,7 +12,7 @@ import mapper
class Query(object):
"""encapsulates the object-fetching operations provided by Mappers."""
- def __init__(self, class_or_mapper, session=None, entity_name=None, **kwargs):
+ def __init__(self, class_or_mapper, session=None, entity_name=None, lockmode=None, **kwargs):
if isinstance(class_or_mapper, type):
self.mapper = mapper.class_mapper(class_or_mapper, entity_name=entity_name)
else:
@@ -20,6 +20,7 @@ class Query(object):
self.mapper = self.mapper.get_select_mapper().compile()
self.always_refresh = kwargs.pop('always_refresh', self.mapper.always_refresh)
self.order_by = kwargs.pop('order_by', self.mapper.order_by)
+ self.lockmode = lockmode
self.extension = kwargs.pop('extension', self.mapper.extension)
self._session = session
if not hasattr(self.mapper, '_get_clause'):
@@ -67,7 +68,8 @@ class Query(object):
e.g. u = usermapper.get_by(user_name = 'fred')
"""
- x = self.select_whereclause(self.join_by(*args, **params), limit=1)
+ lockmode=params.pop('lockmode', self.lockmode)
+ x = self.select_whereclause(self.join_by(*args, **params), lockmode=lockmode, limit=1)
if x:
return x[0]
else:
@@ -248,7 +250,11 @@ class Query(object):
def options(self, *args, **kwargs):
"""returns a new Query object using the given MapperOptions."""
return self.mapper.options(*args, **kwargs).using(session=self._session)
-
+
+ def with_lockmode(self, mode):
+ """return a new Query object with the specified locking mode."""
+ return Query(self.mapper, self._session, lockmode=mode)
+
def __getattr__(self, key):
if (key.startswith('select_by_')):
key = key[10:]
@@ -270,8 +276,9 @@ class Query(object):
finally:
result.close()
- def _get(self, key, ident=None, reload=False):
- if not reload and not self.always_refresh:
+ def _get(self, key, ident=None, reload=False, lockmode=None):
+ lockmode = lockmode or self.lockmode
+ if not reload and not self.always_refresh and lockmode == None:
try:
return self.session._get(key)
except KeyError:
@@ -293,7 +300,7 @@ class Query(object):
if len(ident) > i + 1:
i += 1
try:
- statement = self.compile(self._get_clause)
+ statement = self.compile(self._get_clause, lockmode=lockmode)
return self._select_statement(statement, params=params, populate_existing=reload)[0]
except IndexError:
return None
@@ -320,11 +327,14 @@ class Query(object):
def compile(self, whereclause = None, **kwargs):
order_by = kwargs.pop('order_by', False)
from_obj = kwargs.pop('from_obj', [])
+ lockmode = kwargs.pop('lockmode', self.lockmode)
if order_by is False:
order_by = self.order_by
if order_by is False:
if self.table.default_order_by() is not None:
order_by = self.table.default_order_by()
+
+ for_update = {'read':'read','update':True,'update_nowait':'nowait'}.get(lockmode, False)
if self.mapper.single and self.mapper.polymorphic_on is not None and self.mapper.polymorphic_identity is not None:
whereclause = sql.and_(whereclause, self.mapper.polymorphic_on==self.mapper.polymorphic_identity)
@@ -349,7 +359,7 @@ class Query(object):
crit = []
for i in range(0, len(self.table.primary_key)):
crit.append(s3.primary_key[i] == self.table.primary_key[i])
- statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True)
+ statement = sql.select([], sql.and_(*crit), from_obj=[self.table], use_labels=True, for_update=for_update)
# raise "OK statement", str(statement)
# now for the order by, convert the columns to their corresponding columns
@@ -364,7 +374,7 @@ class Query(object):
statement.order_by(*util.to_list(order_by))
else:
from_obj.append(self.table)
- statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, **kwargs)
+ statement = sql.select([], whereclause, from_obj=from_obj, use_labels=True, for_update=for_update, **kwargs)
if order_by:
statement.order_by(*util.to_list(order_by))
# for a DISTINCT query, you need the columns explicitly specified in order