diff options
| author | mike bayer <mike_mp@zzzcomputing.com> | 2022-06-16 02:30:04 +0000 |
|---|---|---|
| committer | Gerrit Code Review <gerrit@ci3.zzzcomputing.com> | 2022-06-16 02:30:04 +0000 |
| commit | efd4e6dd8a73b956c860d796606f8e6ad652c292 (patch) | |
| tree | fc4e387f37e421d54068310eab5a01facb5f91c8 | |
| parent | 964c26feecc7607d6d3a66240c3f33f4ae9215d4 (diff) | |
| parent | 46c0fa56e904f6a00e56343302c4cb39955fa038 (diff) | |
| download | sqlalchemy-efd4e6dd8a73b956c860d796606f8e6ad652c292.tar.gz | |
Merge "implement literal stringification for arrays" into main
| -rw-r--r-- | doc/build/changelog/unreleased_20/8138.rst | 9 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/array.py | 51 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/sqltypes.py | 58 | ||||
| -rw-r--r-- | lib/sqlalchemy/testing/suite/test_types.py | 57 | ||||
| -rw-r--r-- | setup.cfg | 2 | ||||
| -rw-r--r-- | test/dialect/postgresql/test_types.py | 20 | ||||
| -rw-r--r-- | test/requirements.py | 12 | ||||
| -rw-r--r-- | test/sql/test_types.py | 41 |
8 files changed, 207 insertions, 43 deletions
diff --git a/doc/build/changelog/unreleased_20/8138.rst b/doc/build/changelog/unreleased_20/8138.rst new file mode 100644 index 000000000..510e8f9ed --- /dev/null +++ b/doc/build/changelog/unreleased_20/8138.rst @@ -0,0 +1,9 @@ +.. change:: + :tags: usecase, postgresql + :tickets: 8138 + + Added literal type rendering for the :class:`_sqltypes.ARRAY` and + :class:`_postgresql.ARRAY` datatypes. The generic stringify will render + using brackets, e.g. ``[1, 2, 3]`` and the PostgreSQL specific will use the + ARRAY literal e.g. ``ARRAY[1, 2, 3]``. Multiple dimensions and quoting + are also taken into account. diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index 3b5eaed30..515eb2d15 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -310,35 +310,6 @@ class ARRAY(sqltypes.ARRAY): def compare_values(self, x, y): return x == y - def _proc_array(self, arr, itemproc, dim, collection): - if dim is None: - arr = list(arr) - if ( - dim == 1 - or dim is None - and ( - # this has to be (list, tuple), or at least - # not hasattr('__iter__'), since Py3K strings - # etc. have __iter__ - not arr - or not isinstance(arr[0], (list, tuple)) - ) - ): - if itemproc: - return collection(itemproc(x) for x in arr) - else: - return collection(arr) - else: - return collection( - self._proc_array( - x, - itemproc, - dim - 1 if dim is not None else None, - collection, - ) - for x in arr - ) - @util.memoized_property def _against_native_enum(self): return ( @@ -346,6 +317,24 @@ class ARRAY(sqltypes.ARRAY): and self.item_type.native_enum ) + def literal_processor(self, dialect): + item_proc = self.item_type.dialect_impl(dialect).literal_processor( + dialect + ) + if item_proc is None: + return None + + def to_str(elements): + return f"ARRAY[{', '.join(elements)}]" + + def process(value): + inner = self._apply_item_processor( + value, item_proc, self.dimensions, to_str + ) + return inner + + return process + def bind_processor(self, dialect): item_proc = self.item_type.dialect_impl(dialect).bind_processor( dialect @@ -355,7 +344,7 @@ class ARRAY(sqltypes.ARRAY): if value is None: return value else: - return self._proc_array( + return self._apply_item_processor( value, item_proc, self.dimensions, list ) @@ -370,7 +359,7 @@ class ARRAY(sqltypes.ARRAY): if value is None: return value else: - return self._proc_array( + return self._apply_item_processor( value, item_proc, self.dimensions, diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py index 32f0813f5..b4b444f23 100644 --- a/lib/sqlalchemy/sql/sqltypes.py +++ b/lib/sqlalchemy/sql/sqltypes.py @@ -2964,6 +2964,64 @@ class ARRAY( if isinstance(self.item_type, SchemaEventTarget): self.item_type._set_parent_with_dispatch(parent) + def literal_processor(self, dialect): + item_proc = self.item_type.dialect_impl(dialect).literal_processor( + dialect + ) + if item_proc is None: + return None + + def to_str(elements): + return f"[{', '.join(elements)}]" + + def process(value): + inner = self._apply_item_processor( + value, item_proc, self.dimensions, to_str + ) + return inner + + return process + + def _apply_item_processor(self, arr, itemproc, dim, collection_callable): + """Helper method that can be used by bind_processor(), + literal_processor(), etc. to apply an item processor to elements of + an array value, taking into account the 'dimensions' for this + array type. + + See the Postgresql ARRAY datatype for usage examples. + + .. versionadded:: 2.0 + + """ + + if dim is None: + arr = list(arr) + if ( + dim == 1 + or dim is None + and ( + # this has to be (list, tuple), or at least + # not hasattr('__iter__'), since Py3K strings + # etc. have __iter__ + not arr + or not isinstance(arr[0], (list, tuple)) + ) + ): + if itemproc: + return collection_callable(itemproc(x) for x in arr) + else: + return collection_callable(arr) + else: + return collection_callable( + self._apply_item_processor( + x, + itemproc, + dim - 1 if dim is not None else None, + collection_callable, + ) + for x in arr + ) + class TupleType(TypeEngine[Tuple[Any, ...]]): """represent the composite type of a Tuple.""" diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py index 391379956..9461298b9 100644 --- a/lib/sqlalchemy/testing/suite/test_types.py +++ b/lib/sqlalchemy/testing/suite/test_types.py @@ -17,6 +17,7 @@ from ..config import requirements from ..schema import Column from ..schema import Table from ... import and_ +from ... import ARRAY from ... import BigInteger from ... import bindparam from ... import Boolean @@ -222,6 +223,61 @@ class UnicodeTextTest(_UnicodeFixture, fixtures.TablesTest): self._test_null_strings(connection) +class ArrayTest(_LiteralRoundTripFixture, fixtures.TablesTest): + """Add ARRAY test suite, #8138. + + This only works on PostgreSQL right now. + + """ + + __requires__ = ("array_type",) + __backend__ = True + + @classmethod + def define_tables(cls, metadata): + Table( + "array_table", + metadata, + Column( + "id", Integer, primary_key=True, test_needs_autoincrement=True + ), + Column("single_dim", ARRAY(Integer)), + Column("multi_dim", ARRAY(String, dimensions=2)), + ) + + def test_array_roundtrip(self, connection): + array_table = self.tables.array_table + + connection.execute( + array_table.insert(), + { + "id": 1, + "single_dim": [1, 2, 3], + "multi_dim": [["one", "two"], ["thr'ee", "réve🐍 illé"]], + }, + ) + row = connection.execute( + select(array_table.c.single_dim, array_table.c.multi_dim) + ).first() + eq_(row, ([1, 2, 3], [["one", "two"], ["thr'ee", "réve🐍 illé"]])) + + def test_literal_simple(self, literal_round_trip): + literal_round_trip( + ARRAY(Integer), + ([1, 2, 3],), + ([1, 2, 3],), + support_whereclause=False, + ) + + def test_literal_complex(self, literal_round_trip): + literal_round_trip( + ARRAY(String, dimensions=2), + ([["one", "two"], ["thr'ee", "réve🐍 illé"]],), + ([["one", "two"], ["thr'ee", "réve🐍 illé"]],), + support_whereclause=False, + ) + + class BinaryTest(_LiteralRoundTripFixture, fixtures.TablesTest): __requires__ = ("binary_literals",) __backend__ = True @@ -1779,6 +1835,7 @@ class NativeUUIDTest(UuidTest): __all__ = ( + "ArrayTest", "BinaryTest", "UnicodeVarcharTest", "UnicodeTextTest", @@ -58,7 +58,7 @@ oracle = oracle_oracledb = oracledb>=1.0.1 postgresql = psycopg2>=2.7 -postgresql_pg8000 = pg8000>=1.16.6,!=1.29.0 +postgresql_pg8000 = pg8000>=1.29.1 postgresql_asyncpg = %(asyncio)s asyncpg diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 266263d5f..fd4b91db1 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -19,6 +19,7 @@ from sqlalchemy import Float from sqlalchemy import func from sqlalchemy import inspect from sqlalchemy import Integer +from sqlalchemy import literal from sqlalchemy import MetaData from sqlalchemy import null from sqlalchemy import Numeric @@ -52,6 +53,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import bindparam from sqlalchemy.sql import operators from sqlalchemy.sql import sqltypes +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import fixtures from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_raises_message @@ -64,6 +66,7 @@ from sqlalchemy.testing.assertsql import RegexSQL from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.suite import test_types as suite from sqlalchemy.testing.util import round_decimal +from sqlalchemy.types import UserDefinedType class FloatCoercionTest(fixtures.TablesTest, AssertsExecutionResults): @@ -1230,6 +1233,23 @@ class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): render_postcompile=True, ) + def test_array_literal_render_no_inner_render(self): + class MyType(UserDefinedType): + cache_ok = True + + def get_col_spec(self, **kw): + return "MYTYPE" + + with expect_raises_message( + NotImplementedError, + r"Don't know how to literal-quote value \[1, 2, 3\]", + ): + self.assert_compile( + select(literal([1, 2, 3], ARRAY(MyType()))), + "nothing", + literal_binds=True, + ) + def test_array_in_str_psycopg2_cast(self): expr = column("x", postgresql.ARRAY(String(15))).in_( [["one", "two"], ["three", "four"]] diff --git a/test/requirements.py b/test/requirements.py index 2d0876158..bea861a83 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -969,12 +969,7 @@ class DefaultRequirements(SuiteRequirements): @property def array_type(self): - return only_on( - [ - lambda config: against(config, "postgresql") - and not against(config, "+pg8000") - ] - ) + return only_on([lambda config: against(config, "postgresql")]) @property def json_type(self): @@ -1356,10 +1351,7 @@ class DefaultRequirements(SuiteRequirements): @property def postgresql_jsonb(self): - return only_on("postgresql >= 9.4") + skip_if( - lambda config: config.db.dialect.driver == "pg8000" - and config.db.dialect._dbapi_version <= (1, 10, 1) - ) + return only_on("postgresql >= 9.4") @property def native_hstore(self): diff --git a/test/sql/test_types.py b/test/sql/test_types.py index ef3915726..04aa4e000 100644 --- a/test/sql/test_types.py +++ b/test/sql/test_types.py @@ -93,6 +93,7 @@ from sqlalchemy.testing.schema import Column from sqlalchemy.testing.schema import pep435_enum from sqlalchemy.testing.schema import Table from sqlalchemy.testing.util import picklers +from sqlalchemy.types import UserDefinedType def _all_dialect_modules(): @@ -2904,7 +2905,7 @@ class JSONTest(fixtures.TestBase): eq_(bindproc(expr.right.value), "'five'") -class ArrayTest(fixtures.TestBase): +class ArrayTest(AssertsCompiledSQL, fixtures.TestBase): def _myarray_fixture(self): class MyArray(ARRAY): pass @@ -2957,6 +2958,44 @@ class ArrayTest(fixtures.TestBase): assert isinstance(arrtable.c.intarr[1:3].type, MyArray) assert isinstance(arrtable.c.strarr[1:3].type, MyArray) + def test_array_literal_simple(self): + self.assert_compile( + select(literal([1, 2, 3], ARRAY(Integer))), + "SELECT [1, 2, 3] AS anon_1", + literal_binds=True, + dialect="default", + ) + + def test_array_literal_complex(self): + self.assert_compile( + select( + literal( + [["one", "two"], ["thr'ee", "réve🐍 illé"]], + ARRAY(String, dimensions=2), + ) + ), + "SELECT [['one', 'two'], ['thr''ee', 'réve🐍 illé']] AS anon_1", + literal_binds=True, + dialect="default", + ) + + def test_array_literal_render_no_inner_render(self): + class MyType(UserDefinedType): + cache_ok = True + + def get_col_spec(self, **kw): + return "MYTYPE" + + with expect_raises_message( + NotImplementedError, + r"Don't know how to literal-quote value \[1, 2, 3\]", + ): + self.assert_compile( + select(literal([1, 2, 3], ARRAY(MyType()))), + "nothing", + literal_binds=True, + ) + MyCustomType = MyTypeDec = None |
