diff options
| -rw-r--r-- | doc/build/changelog/unreleased_14/5785.rst | 10 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/elements.py | 25 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/selectable.py | 4 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 5 | ||||
| -rw-r--r-- | test/sql/test_values.py | 129 |
6 files changed, 165 insertions, 12 deletions
diff --git a/doc/build/changelog/unreleased_14/5785.rst b/doc/build/changelog/unreleased_14/5785.rst new file mode 100644 index 000000000..2e07d2da3 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5785.rst @@ -0,0 +1,10 @@ +.. change:: + :tags: bug, sql + :tickets: 5785 + + Fixed issue in new :class:`_sql.Values` construct where passing tuples of + objects would fall back to per-value type detection rather than making use + of the :class:`_schema.Column` objects passed directly to + :class:`_sql.Values` that tells SQLAlchemy what the expected type is. This + would lead to issues for objects such as enumerations and numpy strings + that are not actually necessary since the expected type is given.
\ No newline at end of file diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 46f111d2c..a734bb582 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -2610,7 +2610,9 @@ class SQLCompiler(Compiled): v = "VALUES %s" % ", ".join( self.process( - elements.Tuple(*elem).self_group(), + elements.Tuple( + types=element._column_types, *elem + ).self_group(), literal_binds=element.literal_binds, ) for chunk in element._data diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py index ab8701dd6..75c1fc1bf 100644 --- a/lib/sqlalchemy/sql/elements.py +++ b/lib/sqlalchemy/sql/elements.py @@ -2497,11 +2497,28 @@ class Tuple(ClauseList, ColumnElement): """ sqltypes = util.preloaded.sql_sqltypes - clauses = [ - coercions.expect(roles.ExpressionElementRole, c) for c in clauses - ] - self.type = sqltypes.TupleType(*[arg.type for arg in clauses]) + types = kw.pop("types", None) + if types is None: + clauses = [ + coercions.expect(roles.ExpressionElementRole, c) + for c in clauses + ] + else: + if len(types) != len(clauses): + raise exc.ArgumentError( + "Wrong number of elements for %d-tuple: %r " + % (len(types), clauses) + ) + clauses = [ + coercions.expect( + roles.ExpressionElementRole, + c, + type_=typ if not typ._isnull else None, + ) + for typ, c in zip(types, clauses) + ] + self.type = sqltypes.TupleType(*[arg.type for arg in clauses]) super(Tuple, self).__init__(*clauses, **kw) @property diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py index d60afdbac..b49fe92df 100644 --- a/lib/sqlalchemy/sql/selectable.py +++ b/lib/sqlalchemy/sql/selectable.py @@ -2397,6 +2397,10 @@ class Values(Generative, FromClause): self.literal_binds = kw.pop("literal_binds", False) self.named_with_column = self.name is not None + @property + def _column_types(self): + return [col.type for col in self._column_args] + @_generative def alias(self, name, **kw): """Return a new :class:`_expression.Values` diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 581573d17..09c7388ab 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -3074,7 +3074,9 @@ class NullType(TypeEngine): def literal_processor(self, dialect): def process(value): - return "NULL" + raise exc.CompileError( + "Don't know how to render literal SQL value: %r" % value + ) return process @@ -3131,6 +3133,7 @@ else: _type_map[unicode] = Unicode() # noqa _type_map[str] = String() + _type_map_get = _type_map.get diff --git a/test/sql/test_values.py b/test/sql/test_values.py index 1e4f22442..43e8f8531 100644 --- a/test/sql/test_values.py +++ b/test/sql/test_values.py @@ -1,6 +1,8 @@ from sqlalchemy import alias from sqlalchemy import Column from sqlalchemy import column +from sqlalchemy import Enum +from sqlalchemy import exc from sqlalchemy import ForeignKey from sqlalchemy import Integer from sqlalchemy import String @@ -12,7 +14,9 @@ from sqlalchemy.sql import select from sqlalchemy.sql import Values from sqlalchemy.sql.compiler import FROM_LINTING from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures +from sqlalchemy.util import OrderedDict class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL): @@ -52,34 +56,104 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL): Column("book_weight", Integer), ) + def test_wrong_number_of_elements(self): + v1 = Values( + column("CaseSensitive", Integer), + column("has spaces", String), + name="Spaces and Cases", + ).data([(1, "textA", 99), (2, "textB", 88)]) + + with expect_raises_message( + exc.ArgumentError, + r"Wrong number of elements for 2-tuple: \(1, 'textA', 99\)", + ): + str(v1) + def test_column_quoting(self): v1 = Values( column("CaseSensitive", Integer), column("has spaces", String), + column("number", Integer), name="Spaces and Cases", ).data([(1, "textA", 99), (2, "textB", 88)]) self.assert_compile( select(v1), 'SELECT "Spaces and Cases"."CaseSensitive", ' - '"Spaces and Cases"."has spaces" FROM ' + '"Spaces and Cases"."has spaces", "Spaces and Cases".number FROM ' "(VALUES (:param_1, :param_2, :param_3), " "(:param_4, :param_5, :param_6)) " - 'AS "Spaces and Cases" ("CaseSensitive", "has spaces")', + 'AS "Spaces and Cases" ("CaseSensitive", "has spaces", number)', ) @testing.fixture def literal_parameter_fixture(self): - def go(literal_binds): - return Values( + def go(literal_binds, omit=None): + cols = [ column("mykey", Integer), column("mytext", String), column("myint", Integer), - name="myvalues", - literal_binds=literal_binds, + ] + if omit: + for idx in omit: + cols[idx] = column(cols[idx].name) + + return Values( + *cols, name="myvalues", literal_binds=literal_binds ).data([(1, "textA", 99), (2, "textB", 88)]) return go + @testing.fixture + def tricky_types_parameter_fixture(self): + class SomeEnum(object): + # Implements PEP 435 in the minimal fashion needed by SQLAlchemy + __members__ = OrderedDict() + + def __init__(self, name, value, alias=None): + self.name = name + self.value = value + self.__members__[name] = self + setattr(self.__class__, name, self) + if alias: + self.__members__[alias] = self + setattr(self.__class__, alias, self) + + one = SomeEnum("one", 1) + two = SomeEnum("two", 2) + + class MumPyString(str): + """some kind of string, can't imagine where such a thing might + be found + + """ + + class MumPyNumber(int): + """some kind of int, can't imagine where such a thing might + be found + + """ + + def go(literal_binds, omit=None): + cols = [ + column("mykey", Integer), + column("mytext", String), + column("myenum", Enum(SomeEnum)), + ] + if omit: + for idx in omit: + cols[idx] = column(cols[idx].name) + + return Values( + *cols, name="myvalues", literal_binds=literal_binds + ).data( + [ + (MumPyNumber(1), MumPyString("textA"), one), + (MumPyNumber(2), MumPyString("textB"), two), + ] + ) + + return go + def test_bound_parameters(self, literal_parameter_fixture): literal_parameter_fixture = literal_parameter_fixture(False) @@ -114,6 +188,49 @@ class ValuesTest(fixtures.TablesTest, AssertsCompiledSQL): checkparams={}, ) + def test_literal_parameters_not_every_type_given( + self, literal_parameter_fixture + ): + literal_parameter_fixture = literal_parameter_fixture(True, omit=(1,)) + + stmt = select(literal_parameter_fixture) + + self.assert_compile( + stmt, + "SELECT myvalues.mykey, myvalues.mytext, myvalues.myint FROM " + "(VALUES (1, 'textA', 99), (2, 'textB', 88)" + ") AS myvalues (mykey, mytext, myint)", + checkparams={}, + ) + + def test_use_cols_tricky_not_every_type_given( + self, tricky_types_parameter_fixture + ): + literal_parameter_fixture = tricky_types_parameter_fixture( + True, omit=(1,) + ) + + stmt = select(literal_parameter_fixture) + + with expect_raises_message( + exc.CompileError, + "Don't know how to render literal SQL value: 'textA'", + ): + str(stmt) + + def test_use_cols_for_types(self, tricky_types_parameter_fixture): + literal_parameter_fixture = tricky_types_parameter_fixture(True) + + stmt = select(literal_parameter_fixture) + + self.assert_compile( + stmt, + "SELECT myvalues.mykey, myvalues.mytext, myvalues.myenum FROM " + "(VALUES (1, 'textA', 'one'), (2, 'textB', 'two')" + ") AS myvalues (mykey, mytext, myenum)", + checkparams={}, + ) + def test_with_join_unnamed(self): people = self.tables.people values = Values( |
