summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--doc/build/changelog/unreleased_14/5785.rst10
-rw-r--r--lib/sqlalchemy/sql/compiler.py4
-rw-r--r--lib/sqlalchemy/sql/elements.py25
-rw-r--r--lib/sqlalchemy/sql/selectable.py4
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py5
-rw-r--r--test/sql/test_values.py129
6 files changed, 165 insertions, 12 deletions
diff --git a/doc/build/changelog/unreleased_14/5785.rst b/doc/build/changelog/unreleased_14/5785.rst
new file mode 100644
index 000000000..2e07d2da3
--- /dev/null
+++ b/doc/build/changelog/unreleased_14/5785.rst
@@ -0,0 +1,10 @@
+.. change::
+ :tags: bug, sql
+ :tickets: 5785
+
+ Fixed issue in new :class:`_sql.Values` construct where passing tuples of
+ objects would fall back to per-value type detection rather than making use
+ of the :class:`_schema.Column` objects passed directly to
+ :class:`_sql.Values` that tells SQLAlchemy what the expected type is. This
+ would lead to issues for objects such as enumerations and numpy strings
+ that are not actually necessary since the expected type is given. \ No newline at end of file
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 46f111d2c..a734bb582 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -2610,7 +2610,9 @@ class SQLCompiler(Compiled):
v = "VALUES %s" % ", ".join(
self.process(
- elements.Tuple(*elem).self_group(),
+ elements.Tuple(
+ types=element._column_types, *elem
+ ).self_group(),
literal_binds=element.literal_binds,
)
for chunk in element._data
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index ab8701dd6..75c1fc1bf 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -2497,11 +2497,28 @@ class Tuple(ClauseList, ColumnElement):
"""
sqltypes = util.preloaded.sql_sqltypes
- clauses = [
- coercions.expect(roles.ExpressionElementRole, c) for c in clauses
- ]
- self.type = sqltypes.TupleType(*[arg.type for arg in clauses])
+ types = kw.pop("types", None)
+ if types is None:
+ clauses = [
+ coercions.expect(roles.ExpressionElementRole, c)
+ for c in clauses
+ ]
+ else:
+ if len(types) != len(clauses):
+ raise exc.ArgumentError(
+ "Wrong number of elements for %d-tuple: %r "
+ % (len(types), clauses)
+ )
+ clauses = [
+ coercions.expect(
+ roles.ExpressionElementRole,
+ c,
+ type_=typ if not typ._isnull else None,
+ )
+ for typ, c in zip(types, clauses)
+ ]
+ self.type = sqltypes.TupleType(*[arg.type for arg in clauses])
super(Tuple, self).__init__(*clauses, **kw)
@property
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index d60afdbac..b49fe92df 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -2397,6 +2397,10 @@ class Values(Generative, FromClause):
self.literal_binds = kw.pop("literal_binds", False)
self.named_with_column = self.name is not None
+ @property
+ def _column_types(self):
+ return [col.type for col in self._column_args]
+
@_generative
def alias(self, name, **kw):
"""Return a new :class:`_expression.Values`
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 581573d17..09c7388ab 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -3074,7 +3074,9 @@ class NullType(TypeEngine):
def literal_processor(self, dialect):
def process(value):
- return "NULL"
+ raise exc.CompileError(
+ "Don't know how to render literal SQL value: %r" % value
+ )
return process
@@ -3131,6 +3133,7 @@ else:
_type_map[unicode] = Unicode() # noqa
_type_map[str] = String()
+
_type_map_get = _type_map.get
diff --git a/test/sql/test_values.py b/test/sql/test_values.py
index 1e4f22442..43e8f8531 100644
--- a/test/sql/test_values.py
+++ b/test/sql/test_values.py
@@ -1,6 +1,8 @@
from sqlalchemy import alias
from sqlalchemy import Column
from sqlalchemy import column
+from sqlalchemy import Enum
+from sqlalchemy import exc
from sqlalchemy import ForeignKey
from sqlalchemy import Integer
from sqlalchemy import String
@@ -12,7 +14,9 @@ from sqlalchemy.sql import select
from sqlalchemy.sql import Values
from sqlalchemy.sql.compiler import FROM_LINTING
from sqlalchemy.testing import AssertsCompiledSQL
+from sqlalchemy.testing import expect_raises_message
from sqlalchemy.testing import fixtures
+from sqlalchemy.util import OrderedDict
class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL):
@@ -52,34 +56,104 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL):
Column("book_weight", Integer),
)
+ def test_wrong_number_of_elements(self):
+ v1 = Values(
+ column("CaseSensitive", Integer),
+ column("has spaces", String),
+ name="Spaces and Cases",
+ ).data([(1, "textA", 99), (2, "textB", 88)])
+
+ with expect_raises_message(
+ exc.ArgumentError,
+ r"Wrong number of elements for 2-tuple: \(1, 'textA', 99\)",
+ ):
+ str(v1)
+
def test_column_quoting(self):
v1 = Values(
column("CaseSensitive", Integer),
column("has spaces", String),
+ column("number", Integer),
name="Spaces and Cases",
).data([(1, "textA", 99), (2, "textB", 88)])
self.assert_compile(
select(v1),
'SELECT "Spaces and Cases"."CaseSensitive", '
- '"Spaces and Cases"."has spaces" FROM '
+ '"Spaces and Cases"."has spaces", "Spaces and Cases".number FROM '
"(VALUES (:param_1, :param_2, :param_3), "
"(:param_4, :param_5, :param_6)) "
- 'AS "Spaces and Cases" ("CaseSensitive", "has spaces")',
+ 'AS "Spaces and Cases" ("CaseSensitive", "has spaces", number)',
)
@testing.fixture
def literal_parameter_fixture(self):
- def go(literal_binds):
- return Values(
+ def go(literal_binds, omit=None):
+ cols = [
column("mykey", Integer),
column("mytext", String),
column("myint", Integer),
- name="myvalues",
- literal_binds=literal_binds,
+ ]
+ if omit:
+ for idx in omit:
+ cols[idx] = column(cols[idx].name)
+
+ return Values(
+ *cols, name="myvalues", literal_binds=literal_binds
).data([(1, "textA", 99), (2, "textB", 88)])
return go
+ @testing.fixture
+ def tricky_types_parameter_fixture(self):
+ class SomeEnum(object):
+ # Implements PEP 435 in the minimal fashion needed by SQLAlchemy
+ __members__ = OrderedDict()
+
+ def __init__(self, name, value, alias=None):
+ self.name = name
+ self.value = value
+ self.__members__[name] = self
+ setattr(self.__class__, name, self)
+ if alias:
+ self.__members__[alias] = self
+ setattr(self.__class__, alias, self)
+
+ one = SomeEnum("one", 1)
+ two = SomeEnum("two", 2)
+
+ class MumPyString(str):
+ """some kind of string, can't imagine where such a thing might
+ be found
+
+ """
+
+ class MumPyNumber(int):
+ """some kind of int, can't imagine where such a thing might
+ be found
+
+ """
+
+ def go(literal_binds, omit=None):
+ cols = [
+ column("mykey", Integer),
+ column("mytext", String),
+ column("myenum", Enum(SomeEnum)),
+ ]
+ if omit:
+ for idx in omit:
+ cols[idx] = column(cols[idx].name)
+
+ return Values(
+ *cols, name="myvalues", literal_binds=literal_binds
+ ).data(
+ [
+ (MumPyNumber(1), MumPyString("textA"), one),
+ (MumPyNumber(2), MumPyString("textB"), two),
+ ]
+ )
+
+ return go
+
def test_bound_parameters(self, literal_parameter_fixture):
literal_parameter_fixture = literal_parameter_fixture(False)
@@ -114,6 +188,49 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL):
checkparams={},
)
+ def test_literal_parameters_not_every_type_given(
+ self, literal_parameter_fixture
+ ):
+ literal_parameter_fixture = literal_parameter_fixture(True, omit=(1,))
+
+ stmt = select(literal_parameter_fixture)
+
+ self.assert_compile(
+ stmt,
+ "SELECT myvalues.mykey, myvalues.mytext, myvalues.myint FROM "
+ "(VALUES (1, 'textA', 99), (2, 'textB', 88)"
+ ") AS myvalues (mykey, mytext, myint)",
+ checkparams={},
+ )
+
+ def test_use_cols_tricky_not_every_type_given(
+ self, tricky_types_parameter_fixture
+ ):
+ literal_parameter_fixture = tricky_types_parameter_fixture(
+ True, omit=(1,)
+ )
+
+ stmt = select(literal_parameter_fixture)
+
+ with expect_raises_message(
+ exc.CompileError,
+ "Don't know how to render literal SQL value: 'textA'",
+ ):
+ str(stmt)
+
+ def test_use_cols_for_types(self, tricky_types_parameter_fixture):
+ literal_parameter_fixture = tricky_types_parameter_fixture(True)
+
+ stmt = select(literal_parameter_fixture)
+
+ self.assert_compile(
+ stmt,
+ "SELECT myvalues.mykey, myvalues.mytext, myvalues.myenum FROM "
+ "(VALUES (1, 'textA', 'one'), (2, 'textB', 'two')"
+ ") AS myvalues (mykey, mytext, myenum)",
+ checkparams={},
+ )
+
def test_with_join_unnamed(self):
people = self.tables.people
values = Values(