summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_13/4860.rst6
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py31
-rw-r--r--lib/sqlalchemy/engine/default.py3
-rw-r--r--test/dialect/mysql/test_compiler.py19
-rw-r--r--test/dialect/mysql/test_for_update.py197
-rw-r--r--test/dialect/postgresql/test_compiler.py66
-rw-r--r--test/requirements.py4
7 files changed, 319 insertions, 7 deletions
diff --git a/doc/build/changelog/unreleased_13/4860.rst b/doc/build/changelog/unreleased_13/4860.rst
new file mode 100644
index 000000000..b526ce31e
--- /dev/null
+++ b/doc/build/changelog/unreleased_13/4860.rst
@@ -0,0 +1,6 @@
+.. change::
+ :tags: usecase, mysql
+ :tickets: 4860
+
+ Implemented row-level locking support for mysql. Pull request courtesy
+ Quentin Somerville. \ No newline at end of file
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index dca7b9a00..d009d656e 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -808,6 +808,7 @@ from ...sql import coercions
from ...sql import compiler
from ...sql import elements
from ...sql import roles
+from ...sql import util as sql_util
from ...types import BINARY
from ...types import BLOB
from ...types import BOOLEAN
@@ -1494,9 +1495,28 @@ class MySQLCompiler(compiler.SQLCompiler):
def for_update_clause(self, select, **kw):
if select._for_update_arg.read:
- return " LOCK IN SHARE MODE"
+ tmp = " LOCK IN SHARE MODE"
else:
- return " FOR UPDATE"
+ tmp = " FOR UPDATE"
+
+ if select._for_update_arg.of and self.dialect.supports_for_update_of:
+
+ tables = util.OrderedSet()
+ for c in select._for_update_arg.of:
+ tables.update(sql_util.surface_selectables_only(c))
+
+ tmp += " OF " + ", ".join(
+ self.process(table, ashint=True, use_schema=False, **kw)
+ for table in tables
+ )
+
+ if select._for_update_arg.nowait:
+ tmp += " NOWAIT"
+
+ if select._for_update_arg.skip_locked and self.dialect._is_mysql:
+ tmp += " SKIP LOCKED"
+
+ return tmp
def limit_clause(self, select, **kw):
# MySQL supports:
@@ -2211,6 +2231,9 @@ class MySQLDialect(default.DefaultDialect):
sequences_optional = True
+ supports_for_update_of = False # default for MySQL ...
+ # ... may be updated to True for MySQL 8+ in initialize()
+
supports_sane_rowcount = True
supports_sane_multi_rowcount = False
supports_multivalues_insert = True
@@ -2526,6 +2549,10 @@ class MySQLDialect(default.DefaultDialect):
self._is_mariadb and self.server_version_info >= (10, 3)
)
+ self.supports_for_update_of = (
+ self._is_mysql and self.server_version_info >= (8,)
+ )
+
self._needs_correct_for_88718_96365 = (
not self._is_mariadb and self.server_version_info >= (8,)
)
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 52651aa2d..e30daaeb8 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -128,6 +128,9 @@ class DefaultDialect(interfaces.Dialect):
supports_server_side_cursors = False
+ # extra record-level locking features (#4860)
+ supports_for_update_of = False
+
server_version_info = None
default_schema_name = None
diff --git a/test/dialect/mysql/test_compiler.py b/test/dialect/mysql/test_compiler.py
index 4e6199c6f..167460cba 100644
--- a/test/dialect/mysql/test_compiler.py
+++ b/test/dialect/mysql/test_compiler.py
@@ -353,21 +353,30 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
expr = literal("x", type_=String) + literal("y", type_=String)
self.assert_compile(expr, "concat('x', 'y')", literal_binds=True)
- def test_for_update(self):
+ def test_mariadb_for_update(self):
+ dialect = mysql.dialect()
+ dialect.server_version_info = (10, 1, 1, "MariaDB")
+
table1 = table(
"mytable", column("myid"), column("name"), column("description")
)
self.assert_compile(
- table1.select(table1.c.myid == 7).with_for_update(),
+ table1.select(table1.c.myid == 7).with_for_update(of=table1),
"SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE",
+ dialect=dialect,
)
self.assert_compile(
- table1.select(table1.c.myid == 7).with_for_update(read=True),
+ table1.select(table1.c.myid == 7).with_for_update(
+ skip_locked=True
+ ),
"SELECT mytable.myid, mytable.name, mytable.description "
- "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE",
+ dialect=dialect,
)
def test_delete_extra_froms(self):
diff --git a/test/dialect/mysql/test_for_update.py b/test/dialect/mysql/test_for_update.py
index 5897a094d..2c247a5c0 100644
--- a/test/dialect/mysql/test_for_update.py
+++ b/test/dialect/mysql/test_for_update.py
@@ -11,9 +11,13 @@ from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import testing
from sqlalchemy import update
+from sqlalchemy.dialects.mysql import base as mysql
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import relationship
from sqlalchemy.orm import Session
+from sqlalchemy.sql import column
+from sqlalchemy.sql import table
+from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import fixtures
@@ -160,3 +164,196 @@ class MySQLForUpdateLockingTest(fixtures.DeclarativeMappedTest):
# no subquery, should be locked
self._assert_a_is_locked(True)
self._assert_b_is_locked(True)
+
+
+class MySQLForUpdateCompileTest(fixtures.TestBase, AssertsCompiledSQL):
+ __dialect__ = mysql.dialect()
+
+ table1 = table(
+ "mytable", column("myid"), column("name"), column("description")
+ )
+ table2 = table("table2", column("mytable_id"))
+ join = table2.join(table1, table2.c.mytable_id == table1.c.myid)
+ for_update_of_dialect = mysql.dialect()
+ for_update_of_dialect.server_version_info = (8, 0, 0)
+ for_update_of_dialect.supports_for_update_of = True
+
+ def test_for_update_basic(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s FOR UPDATE",
+ )
+
+ def test_for_update_read(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s LOCK IN SHARE MODE",
+ )
+
+ def test_for_update_skip_locked(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE SKIP LOCKED",
+ )
+
+ def test_for_update_read_and_skip_locked(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE SKIP LOCKED",
+ )
+
+ def test_for_update_nowait(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE NOWAIT",
+ )
+
+ def test_for_update_read_and_nowait(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE NOWAIT",
+ )
+
+ def test_for_update_of_nowait(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ of=self.table1, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE OF mytable NOWAIT",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_basic(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ of=self.table1
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE OF mytable",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_skip_locked(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ of=self.table1, skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "FOR UPDATE OF mytable SKIP LOCKED",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_join_one(self):
+ self.assert_compile(
+ self.join.select(self.table2.c.mytable_id == 7).with_for_update(
+ of=[self.join]
+ ),
+ "SELECT table2.mytable_id, "
+ "mytable.myid, mytable.name, mytable.description "
+ "FROM table2 "
+ "INNER JOIN mytable ON table2.mytable_id = mytable.myid "
+ "WHERE table2.mytable_id = %s "
+ "FOR UPDATE OF mytable, table2",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_column_list_aliased(self):
+ ta = self.table1.alias()
+ self.assert_compile(
+ ta.select(ta.c.myid == 7).with_for_update(
+ of=[ta.c.myid, ta.c.name]
+ ),
+ "SELECT mytable_1.myid, mytable_1.name, mytable_1.description "
+ "FROM mytable AS mytable_1 "
+ "WHERE mytable_1.myid = %s FOR UPDATE OF mytable_1",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_join_aliased(self):
+ ta = self.table1.alias()
+ alias_join = self.table2.join(
+ ta, self.table2.c.mytable_id == ta.c.myid
+ )
+ self.assert_compile(
+ alias_join.select(self.table2.c.mytable_id == 7).with_for_update(
+ of=[alias_join]
+ ),
+ "SELECT table2.mytable_id, "
+ "mytable_1.myid, mytable_1.name, mytable_1.description "
+ "FROM table2 "
+ "INNER JOIN mytable AS mytable_1 "
+ "ON table2.mytable_id = mytable_1.myid "
+ "WHERE table2.mytable_id = %s "
+ "FOR UPDATE OF mytable_1, table2",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_read_nowait(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, of=self.table1, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE OF mytable NOWAIT",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_read_skip_locked(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, of=self.table1, skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE OF mytable SKIP LOCKED",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_read_nowait_column_list(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True,
+ of=[self.table1.c.myid, self.table1.c.name],
+ nowait=True,
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE OF mytable NOWAIT",
+ dialect=self.for_update_of_dialect,
+ )
+
+ def test_for_update_of_read(self):
+ self.assert_compile(
+ self.table1.select(self.table1.c.myid == 7).with_for_update(
+ read=True, of=self.table1
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %s "
+ "LOCK IN SHARE MODE OF mytable",
+ dialect=self.for_update_of_dialect,
+ )
diff --git a/test/dialect/postgresql/test_compiler.py b/test/dialect/postgresql/test_compiler.py
index 4cc9c837d..c707137a8 100644
--- a/test/dialect/postgresql/test_compiler.py
+++ b/test/dialect/postgresql/test_compiler.py
@@ -952,6 +952,24 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
table1.select(table1.c.myid == 7).with_for_update(
+ key_share=True, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR NO KEY UPDATE NOWAIT",
+ )
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ key_share=True, read=True, nowait=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR KEY SHARE NOWAIT",
+ )
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
read=True, skip_locked=True
),
"SELECT mytable.myid, mytable.name, mytable.description "
@@ -979,6 +997,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
table1.select(table1.c.myid == 7).with_for_update(
+ key_share=True, read=True, nowait=True, of=table1
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR KEY SHARE OF mytable NOWAIT",
+ )
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
read=True, nowait=True, of=table1.c.myid
),
"SELECT mytable.myid, mytable.name, mytable.description "
@@ -997,6 +1024,27 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
table1.select(table1.c.myid == 7).with_for_update(
+ read=True,
+ skip_locked=True,
+ of=[table1.c.myid, table1.c.name],
+ key_share=True,
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR KEY SHARE OF mytable SKIP LOCKED",
+ )
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ skip_locked=True, of=[table1.c.myid, table1.c.name]
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR UPDATE OF mytable SKIP LOCKED",
+ )
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
read=True, skip_locked=True, of=[table1.c.myid, table1.c.name]
),
"SELECT mytable.myid, mytable.name, mytable.description "
@@ -1060,6 +1108,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
self.assert_compile(
table1.select(table1.c.myid == 7).with_for_update(
+ read=True, of=table1
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR SHARE OF mytable",
+ )
+
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
read=True, key_share=True, skip_locked=True
),
"SELECT mytable.myid, mytable.name, mytable.description "
@@ -1067,6 +1124,15 @@ class CompileTest(fixtures.TestBase, AssertsCompiledSQL):
"FOR KEY SHARE SKIP LOCKED",
)
+ self.assert_compile(
+ table1.select(table1.c.myid == 7).with_for_update(
+ key_share=True, skip_locked=True
+ ),
+ "SELECT mytable.myid, mytable.name, mytable.description "
+ "FROM mytable WHERE mytable.myid = %(myid_1)s "
+ "FOR NO KEY UPDATE SKIP LOCKED",
+ )
+
ta = table1.alias()
self.assert_compile(
ta.select(ta.c.myid == 7).with_for_update(
diff --git a/test/requirements.py b/test/requirements.py
index 189aecb5e..969b4ea83 100644
--- a/test/requirements.py
+++ b/test/requirements.py
@@ -1626,3 +1626,7 @@ class DefaultRequirements(SuiteRequirements):
def supports_distinct_on(self):
"""If a backend supports the DISTINCT ON in a select"""
return only_if(["postgresql"])
+
+ @property
+ def supports_for_update_of(self):
+ return only_if(lambda config: config.db.dialect.supports_for_update_of)