summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2018-04-04 13:36:28 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2018-04-04 16:46:16 -0400
commitb4eb29253cb29a069973503f36d1103d4a18311c (patch)
treed74797804981a1234b993569fa426e78ba7a6e00
parent9f986ce10c6755af3f347a56f9ea03e0e2c5943e (diff)
downloadsqlalchemy-b4eb29253cb29a069973503f36d1103d4a18311c.tar.gz
Ensure all visit_sequence accepts **kw args
Fixed issue where the compilation of an INSERT statement with the "literal_binds" option that also uses an explicit sequence and "inline" generation, as on Postgresql and Oracle, would fail to accommodate the extra keyword argument within the sequence processing routine. Change-Id: Ibdab7d340aea7429a210c9535ccf1a3e85f074fb Fixes: #4231
-rw-r--r--doc/build/changelog/unreleased_12/4231.rst9
-rw-r--r--lib/sqlalchemy/dialects/firebird/base.py2
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py2
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py2
-rw-r--r--lib/sqlalchemy/sql/compiler.py2
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequence.py24
-rw-r--r--test/sql/test_compiler.py15
7 files changed, 50 insertions, 6 deletions
diff --git a/doc/build/changelog/unreleased_12/4231.rst b/doc/build/changelog/unreleased_12/4231.rst
new file mode 100644
index 000000000..47e70ef02
--- /dev/null
+++ b/doc/build/changelog/unreleased_12/4231.rst
@@ -0,0 +1,9 @@
+.. change::
+ :tags: bug, sql
+ :tickets: 4231
+ :versions: 1.3.0b1
+
+ Fixed issue where the compilation of an INSERT statement with the
+ "literal_binds" option that also uses an explicit sequence and "inline"
+ generation, as on Postgresql and Oracle, would fail to accommodate the
+ extra keyword argument within the sequence processing routine.
diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py
index 335163f15..7b470c189 100644
--- a/lib/sqlalchemy/dialects/firebird/base.py
+++ b/lib/sqlalchemy/dialects/firebird/base.py
@@ -291,7 +291,7 @@ class FBCompiler(sql.compiler.SQLCompiler):
def default_from(self):
return " FROM rdb$database"
- def visit_sequence(self, seq):
+ def visit_sequence(self, seq, **kw):
return "gen_id(%s, 1)" % self.preparer.format_sequence(seq)
def get_select_precolumns(self, select, **kw):
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index 3970a181c..44ab9e3bb 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -767,7 +767,7 @@ class OracleCompiler(compiler.SQLCompiler):
def visit_outer_join_column(self, vc, **kw):
return self.process(vc.column, **kw) + "(+)"
- def visit_sequence(self, seq):
+ def visit_sequence(self, seq, **kw):
return (self.dialect.identifier_preparer.format_sequence(seq) +
".nextval")
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index c5b0db6ce..0160239b7 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1489,7 +1489,7 @@ class PGCompiler(compiler.SQLCompiler):
value = value.replace('\\', '\\\\')
return value
- def visit_sequence(self, seq):
+ def visit_sequence(self, seq, **kw):
return "nextval('%s')" % self.preparer.format_sequence(seq)
def limit_clause(self, select, **kw):
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 6c7e6145d..a442c65fd 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -934,7 +934,7 @@ class SQLCompiler(Compiled):
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
- def visit_sequence(self, sequence):
+ def visit_sequence(self, sequence, **kw):
raise NotImplementedError(
"Dialect '%s' does not support sequence increments." %
self.dialect.name
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
index b2d52f27c..f1c00de6b 100644
--- a/lib/sqlalchemy/testing/suite/test_sequence.py
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -3,7 +3,7 @@ from ..config import requirements
from ..assertions import eq_
from ... import testing
-from ... import Integer, String, Sequence, schema
+from ... import Integer, String, Sequence, schema, MetaData
from ..schema import Table, Column
@@ -71,6 +71,28 @@ class SequenceTest(fixtures.TablesTest):
)
+class SequenceCompilerTest(testing.AssertsCompiledSQL, fixtures.TestBase):
+ __requires__ = ('sequences',)
+ __backend__ = True
+
+ def test_literal_binds_inline_compile(self):
+ table = Table(
+ 'x', MetaData(),
+ Column('y', Integer, Sequence('y_seq')),
+ Column('q', Integer))
+
+ stmt = table.insert().values(q=5)
+
+ seq_nextval = testing.db.dialect.statement_compiler(
+ statement=None, dialect=testing.db.dialect).visit_sequence(
+ Sequence("y_seq"))
+ self.assert_compile(
+ stmt,
+ "INSERT INTO x (y, q) VALUES (%s, 5)" % (seq_nextval, ),
+ literal_binds=True,
+ dialect=testing.db.dialect)
+
+
class HasSequenceTest(fixtures.TestBase):
__requires__ = 'sequences',
__backend__ = True
diff --git a/test/sql/test_compiler.py b/test/sql/test_compiler.py
index 25eb2b24b..0ef19e0cb 100644
--- a/test/sql/test_compiler.py
+++ b/test/sql/test_compiler.py
@@ -19,7 +19,7 @@ from sqlalchemy import Integer, String, MetaData, Table, Column, select, \
literal, and_, null, type_coerce, alias, or_, literal_column,\
Float, TIMESTAMP, Numeric, Date, Text, union, except_,\
intersect, union_all, Boolean, distinct, join, outerjoin, asc, desc,\
- over, subquery, case, true, CheckConstraint
+ over, subquery, case, true, CheckConstraint, Sequence
import decimal
from sqlalchemy.util import u
from sqlalchemy import exc, sql, util, types, schema
@@ -2955,6 +2955,19 @@ class CRUDTest(fixtures.TestBase, AssertsCompiledSQL):
"INSERT INTO mytable (myid, name) VALUES (3, 'jack')",
literal_binds=True)
+ def test_insert_literal_binds_sequence_notimplemented(self):
+ table = Table('x', MetaData(), Column('y', Integer, Sequence('y_seq')))
+ dialect = default.DefaultDialect()
+ dialect.supports_sequences = True
+
+ stmt = table.insert().values(myid=3, name='jack')
+
+ assert_raises(
+ NotImplementedError,
+ stmt.compile,
+ compile_kwargs=dict(literal_binds=True), dialect=dialect
+ )
+
def test_update_literal_binds(self):
stmt = table1.update().values(name='jack').\
where(table1.c.name == 'jill')