summaryrefslogtreecommitdiff
path: root/test/dialect
diff options
context:
space:
mode:
authorFederico Caselli <cfederico87@gmail.com>2021-09-14 23:38:00 +0200
committerMike Bayer <mike_mp@zzzcomputing.com>2021-11-26 10:14:44 -0500
commit5eb407f84bdabdbcd68975dbf76dc4c0809d7373 (patch)
tree0d37ab4b9c28d8a0fa6cefdcc1933d52ffd9a599 /test/dialect
parent8ddb3ef165d0c2d6d7167bb861bb349e68b5e8df (diff)
downloadsqlalchemy-5eb407f84bdabdbcd68975dbf76dc4c0809d7373.tar.gz
Added support for ``psycopg`` dialect.
Both sync and async versions are supported. Fixes: #6842 Change-Id: I57751c5028acebfc6f9c43572562405453a2f2a4
Diffstat (limited to 'test/dialect')
-rw-r--r--test/dialect/postgresql/test_dialect.py134
-rw-r--r--test/dialect/postgresql/test_query.py161
-rw-r--r--test/dialect/postgresql/test_types.py30
3 files changed, 278 insertions, 47 deletions
diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py
index 57682686c..02d7ad483 100644
--- a/test/dialect/postgresql/test_dialect.py
+++ b/test/dialect/postgresql/test_dialect.py
@@ -8,6 +8,7 @@ from sqlalchemy import BigInteger
from sqlalchemy import bindparam
from sqlalchemy import cast
from sqlalchemy import Column
+from sqlalchemy import create_engine
from sqlalchemy import DateTime
from sqlalchemy import DDL
from sqlalchemy import event
@@ -30,6 +31,9 @@ from sqlalchemy import text
from sqlalchemy import TypeDecorator
from sqlalchemy import util
from sqlalchemy.dialects.postgresql import base as postgresql
+from sqlalchemy.dialects.postgresql import HSTORE
+from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.dialects.postgresql import psycopg as psycopg_dialect
from sqlalchemy.dialects.postgresql import psycopg2 as psycopg2_dialect
from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_BATCH
from sqlalchemy.dialects.postgresql.psycopg2 import EXECUTEMANY_PLAIN
@@ -269,10 +273,12 @@ class PGCodeTest(fixtures.TestBase):
if testing.against("postgresql+pg8000"):
# TODO: is there another way we're supposed to see this?
eq_(errmsg.orig.args[0]["C"], "23505")
- else:
+ elif not testing.against("postgresql+psycopg"):
eq_(errmsg.orig.pgcode, "23505")
- if testing.against("postgresql+asyncpg"):
+ if testing.against("postgresql+asyncpg") or testing.against(
+ "postgresql+psycopg"
+ ):
eq_(errmsg.orig.sqlstate, "23505")
@@ -858,6 +864,13 @@ class MiscBackendTest(
".".join(str(x) for x in v)
)
+ @testing.only_on("postgresql+psycopg")
+ def test_psycopg_version(self):
+ v = testing.db.dialect.psycopg_version
+ assert testing.db.dialect.dbapi.__version__.startswith(
+ ".".join(str(x) for x in v)
+ )
+
@testing.combinations(
((8, 1), False, False),
((8, 1), None, False),
@@ -902,6 +915,7 @@ class MiscBackendTest(
with testing.db.connect().execution_options(
isolation_level="SERIALIZABLE"
) as conn:
+
dbapi_conn = conn.connection.dbapi_connection
is_false(dbapi_conn.autocommit)
@@ -1069,25 +1083,30 @@ class MiscBackendTest(
dbapi_conn.rollback()
eq_(val, "off")
- @testing.requires.psycopg2_compatibility
- def test_psycopg2_non_standard_err(self):
+ @testing.requires.psycopg_compatibility
+ def test_psycopg_non_standard_err(self):
# note that psycopg2 is sometimes called psycopg2cffi
# depending on platform
- psycopg2 = testing.db.dialect.dbapi
- TransactionRollbackError = __import__(
- "%s.extensions" % psycopg2.__name__
- ).extensions.TransactionRollbackError
+ psycopg = testing.db.dialect.dbapi
+ if psycopg.__version__.startswith("3"):
+ TransactionRollbackError = __import__(
+ "%s.errors" % psycopg.__name__
+ ).errors.TransactionRollback
+ else:
+ TransactionRollbackError = __import__(
+ "%s.extensions" % psycopg.__name__
+ ).extensions.TransactionRollbackError
exception = exc.DBAPIError.instance(
"some statement",
{},
TransactionRollbackError("foo"),
- psycopg2.Error,
+ psycopg.Error,
)
assert isinstance(exception, exc.OperationalError)
@testing.requires.no_coverage
- @testing.requires.psycopg2_compatibility
+ @testing.requires.psycopg_compatibility
def test_notice_logging(self):
log = logging.getLogger("sqlalchemy.dialects.postgresql")
buf = logging.handlers.BufferingHandler(100)
@@ -1115,14 +1134,14 @@ $$ LANGUAGE plpgsql;
finally:
log.removeHandler(buf)
log.setLevel(lev)
- msgs = " ".join(b.msg for b in buf.buffer)
+ msgs = " ".join(b.getMessage() for b in buf.buffer)
eq_regex(
msgs,
- "NOTICE: notice: hi there(\nCONTEXT: .*?)? "
- "NOTICE: notice: another note(\nCONTEXT: .*?)?",
+ "NOTICE: [ ]?notice: hi there(\nCONTEXT: .*?)? "
+ "NOTICE: [ ]?notice: another note(\nCONTEXT: .*?)?",
)
- @testing.requires.psycopg2_or_pg8000_compatibility
+ @testing.requires.psycopg_or_pg8000_compatibility
@engines.close_open_connections
def test_client_encoding(self):
c = testing.db.connect()
@@ -1143,7 +1162,7 @@ $$ LANGUAGE plpgsql;
new_encoding = c.exec_driver_sql("show client_encoding").fetchone()[0]
eq_(new_encoding, test_encoding)
- @testing.requires.psycopg2_or_pg8000_compatibility
+ @testing.requires.psycopg_or_pg8000_compatibility
@engines.close_open_connections
def test_autocommit_isolation_level(self):
c = testing.db.connect().execution_options(
@@ -1302,7 +1321,7 @@ $$ LANGUAGE plpgsql;
assert result == [(1, "user", "lala")]
connection.execute(text("DROP TABLE speedy_users"))
- @testing.requires.psycopg2_or_pg8000_compatibility
+ @testing.requires.psycopg_or_pg8000_compatibility
def test_numeric_raise(self, connection):
stmt = text("select cast('hi' as char) as hi").columns(hi=Numeric)
assert_raises(exc.InvalidRequestError, connection.execute, stmt)
@@ -1364,9 +1383,90 @@ $$ LANGUAGE plpgsql;
)
@testing.requires.psycopg2_compatibility
- def test_initial_transaction_state(self):
+ def test_initial_transaction_state_psycopg2(self):
from psycopg2.extensions import STATUS_IN_TRANSACTION
engine = engines.testing_engine()
with engine.connect() as conn:
ne_(conn.connection.status, STATUS_IN_TRANSACTION)
+
+ @testing.only_on("postgresql+psycopg")
+ def test_initial_transaction_state_psycopg(self):
+ from psycopg.pq import TransactionStatus
+
+ engine = engines.testing_engine()
+ with engine.connect() as conn:
+ ne_(
+ conn.connection.dbapi_connection.info.transaction_status,
+ TransactionStatus.INTRANS,
+ )
+
+
+class Psycopg3Test(fixtures.TestBase):
+ __only_on__ = ("postgresql+psycopg",)
+
+ def test_json_correctly_registered(self, testing_engine):
+ import json
+
+ def loads(value):
+ value = json.loads(value)
+ value["x"] = value["x"] + "_loads"
+ return value
+
+ def dumps(value):
+ value = dict(value)
+ value["x"] = "dumps_y"
+ return json.dumps(value)
+
+ engine = testing_engine(
+ options=dict(json_serializer=dumps, json_deserializer=loads)
+ )
+ engine2 = testing_engine(
+ options=dict(
+ json_serializer=json.dumps, json_deserializer=json.loads
+ )
+ )
+
+ s = select(cast({"key": "value", "x": "q"}, JSONB))
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
+ with engine2.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "q"})
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "dumps_y_loads"})
+
+ @testing.requires.hstore
+ def test_hstore_correctly_registered(self, testing_engine):
+ engine = testing_engine(options=dict(use_native_hstore=True))
+ engine2 = testing_engine(options=dict(use_native_hstore=False))
+
+ def rp(self, *a):
+ return lambda a: {"a": "b"}
+
+ with mock.patch.object(HSTORE, "result_processor", side_effect=rp):
+ s = select(cast({"key": "value", "x": "q"}, HSTORE))
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "q"})
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "q"})
+ with engine2.begin() as conn:
+ eq_(conn.scalar(s), {"a": "b"})
+ with engine.begin() as conn:
+ eq_(conn.scalar(s), {"key": "value", "x": "q"})
+
+ def test_get_dialect(self):
+ u = url.URL.create("postgresql://")
+ d = psycopg_dialect.PGDialect_psycopg.get_dialect_cls(u)
+ is_(d, psycopg_dialect.PGDialect_psycopg)
+ d = psycopg_dialect.PGDialect_psycopg.get_async_dialect_cls(u)
+ is_(d, psycopg_dialect.PGDialectAsync_psycopg)
+ d = psycopg_dialect.PGDialectAsync_psycopg.get_dialect_cls(u)
+ is_(d, psycopg_dialect.PGDialectAsync_psycopg)
+ d = psycopg_dialect.PGDialectAsync_psycopg.get_dialect_cls(u)
+ is_(d, psycopg_dialect.PGDialectAsync_psycopg)
+
+ def test_async_version(self):
+ e = create_engine("postgresql+psycopg_async://")
+ is_true(isinstance(e.dialect, psycopg_dialect.PGDialectAsync_psycopg))
diff --git a/test/dialect/postgresql/test_query.py b/test/dialect/postgresql/test_query.py
index b488b146c..fdce643f8 100644
--- a/test/dialect/postgresql/test_query.py
+++ b/test/dialect/postgresql/test_query.py
@@ -14,6 +14,7 @@ from sqlalchemy import Float
from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Integer
+from sqlalchemy import JSON
from sqlalchemy import literal
from sqlalchemy import literal_column
from sqlalchemy import MetaData
@@ -29,6 +30,7 @@ from sqlalchemy import true
from sqlalchemy import tuple_
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import JSONB
+from sqlalchemy.sql.expression import type_coerce
from sqlalchemy.testing import assert_raises
from sqlalchemy.testing import AssertsCompiledSQL
from sqlalchemy.testing import AssertsExecutionResults
@@ -40,6 +42,17 @@ from sqlalchemy.testing.assertsql import CursorSQL
from sqlalchemy.testing.assertsql import DialectSQL
+class FunctionTypingTest(fixtures.TestBase, AssertsExecutionResults):
+ __only_on__ = "postgresql"
+ __backend__ = True
+
+ def test_count_star(self, connection):
+ eq_(connection.scalar(func.count("*")), 1)
+
+ def test_count_int(self, connection):
+ eq_(connection.scalar(func.count(1)), 1)
+
+
class InsertTest(fixtures.TestBase, AssertsExecutionResults):
__only_on__ = "postgresql"
@@ -956,23 +969,42 @@ class MatchTest(fixtures.TablesTest, AssertsCompiledSQL):
],
)
+ def _strs_render_bind_casts(self, connection):
+
+ return (
+ connection.dialect._bind_typing_render_casts
+ and String().dialect_impl(connection.dialect).render_bind_cast
+ )
+
@testing.requires.pyformat_paramstyle
- def test_expression_pyformat(self):
+ def test_expression_pyformat(self, connection):
matchtable = self.tables.matchtable
- self.assert_compile(
- matchtable.c.title.match("somstr"),
- "matchtable.title @@ to_tsquery(%(title_1)s" ")",
- )
+
+ if self._strs_render_bind_casts(connection):
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%(title_1)s::VARCHAR(200))",
+ )
+ else:
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%(title_1)s)",
+ )
@testing.requires.format_paramstyle
- def test_expression_positional(self):
+ def test_expression_positional(self, connection):
matchtable = self.tables.matchtable
- self.assert_compile(
- matchtable.c.title.match("somstr"),
- # note we assume current tested DBAPIs use emulated setinputsizes
- # here, the cast is not strictly necessary
- "matchtable.title @@ to_tsquery(%s::VARCHAR(200))",
- )
+
+ if self._strs_render_bind_casts(connection):
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%s::VARCHAR(200))",
+ )
+ else:
+ self.assert_compile(
+ matchtable.c.title.match("somstr"),
+ "matchtable.title @@ to_tsquery(%s)",
+ )
def test_simple_match(self, connection):
matchtable = self.tables.matchtable
@@ -1551,17 +1583,106 @@ class TableValuedRoundTripTest(fixtures.TestBase):
[(14, 1), (41, 2), (7, 3), (54, 4), (9, 5), (49, 6)],
)
- @testing.only_on(
- "postgresql+psycopg2",
- "I cannot get this to run at all on other drivers, "
- "even selecting from a table",
+ @testing.combinations(
+ (
+ type_coerce,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ cast,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ None,
+ testing.fails_on_everything_except(
+ ["postgresql+psycopg2"],
+ "I cannot get this to run at all on other drivers, "
+ "even selecting from a table",
+ ),
+ ),
+ argnames="cast_fn",
)
- def test_render_derived_quoting(self, connection):
+ def test_render_derived_quoting_text(self, connection, cast_fn):
+
+ value = (
+ '[{"CaseSensitive":1,"the % value":"foo"}, '
+ '{"CaseSensitive":"2","the % value":"bar"}]'
+ )
+
+ if cast_fn:
+ value = cast_fn(value, JSON)
+
fn = (
- func.json_to_recordset( # noqa
- '[{"CaseSensitive":1,"the % value":"foo"}, '
- '{"CaseSensitive":"2","the % value":"bar"}]'
+ func.json_to_recordset(value)
+ .table_valued(
+ column("CaseSensitive", Integer), column("the % value", String)
)
+ .render_derived(with_types=True)
+ )
+
+ stmt = select(fn.c.CaseSensitive, fn.c["the % value"])
+
+ eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")])
+
+ @testing.combinations(
+ (
+ type_coerce,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ cast,
+ testing.fails("fails on all drivers"),
+ ),
+ (
+ None,
+ testing.fails("Fails on all drivers"),
+ ),
+ argnames="cast_fn",
+ )
+ def test_render_derived_quoting_text_to_json(self, connection, cast_fn):
+
+ value = (
+ '[{"CaseSensitive":1,"the % value":"foo"}, '
+ '{"CaseSensitive":"2","the % value":"bar"}]'
+ )
+
+ if cast_fn:
+ value = cast_fn(value, JSON)
+
+ # why wont this work?!?!?
+ # should be exactly json_to_recordset(to_json('string'::text))
+ #
+ fn = (
+ func.json_to_recordset(func.to_json(value))
+ .table_valued(
+ column("CaseSensitive", Integer), column("the % value", String)
+ )
+ .render_derived(with_types=True)
+ )
+
+ stmt = select(fn.c.CaseSensitive, fn.c["the % value"])
+
+ eq_(connection.execute(stmt).all(), [(1, "foo"), (2, "bar")])
+
+ @testing.combinations(
+ (type_coerce,),
+ (cast,),
+ (None, testing.fails("Fails on all PG backends")),
+ argnames="cast_fn",
+ )
+ def test_render_derived_quoting_straight_json(self, connection, cast_fn):
+ # these all work
+
+ value = [
+ {"CaseSensitive": 1, "the % value": "foo"},
+ {"CaseSensitive": "2", "the % value": "bar"},
+ ]
+
+ if cast_fn:
+ value = cast_fn(value, JSON)
+
+ fn = (
+ func.json_to_recordset(value) # noqa
.table_valued(
column("CaseSensitive", Integer), column("the % value", String)
)
diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py
index 4008881d2..5f8a41d1f 100644
--- a/test/dialect/postgresql/test_types.py
+++ b/test/dialect/postgresql/test_types.py
@@ -2360,7 +2360,7 @@ class ArrayEnum(fixtures.TestBase):
testing.combinations(
sqltypes.ARRAY,
postgresql.ARRAY,
- (_ArrayOfEnum, testing.only_on("postgresql+psycopg2")),
+ (_ArrayOfEnum, testing.requires.psycopg_compatibility),
argnames="array_cls",
)(fn)
)
@@ -3066,7 +3066,7 @@ class HStoreRoundTripTest(fixtures.TablesTest):
@testing.fixture
def non_native_hstore_connection(self, testing_engine):
- local_engine = testing.requires.psycopg2_native_hstore.enabled
+ local_engine = testing.requires.native_hstore.enabled
if local_engine:
engine = testing_engine(options=dict(use_native_hstore=False))
@@ -3096,14 +3096,14 @@ class HStoreRoundTripTest(fixtures.TablesTest):
)["1"]
eq_(connection.scalar(select(expr)), "3")
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_insert_native(self, connection):
self._test_insert(connection)
def test_insert_python(self, non_native_hstore_connection):
self._test_insert(non_native_hstore_connection)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_criterion_native(self, connection):
self._fixture_data(connection)
self._test_criterion(connection)
@@ -3134,7 +3134,7 @@ class HStoreRoundTripTest(fixtures.TablesTest):
def test_fixed_round_trip_python(self, non_native_hstore_connection):
self._test_fixed_round_trip(non_native_hstore_connection)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_fixed_round_trip_native(self, connection):
self._test_fixed_round_trip(connection)
@@ -3154,11 +3154,11 @@ class HStoreRoundTripTest(fixtures.TablesTest):
},
)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_unicode_round_trip_python(self, non_native_hstore_connection):
self._test_unicode_round_trip(non_native_hstore_connection)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_unicode_round_trip_native(self, connection):
self._test_unicode_round_trip(connection)
@@ -3167,7 +3167,7 @@ class HStoreRoundTripTest(fixtures.TablesTest):
):
self._test_escaped_quotes_round_trip(non_native_hstore_connection)
- @testing.requires.psycopg2_native_hstore
+ @testing.requires.native_hstore
def test_escaped_quotes_round_trip_native(self, connection):
self._test_escaped_quotes_round_trip(connection)
@@ -3356,7 +3356,7 @@ class _RangeTypeCompilation(AssertsCompiledSQL, fixtures.TestBase):
class _RangeTypeRoundTrip(fixtures.TablesTest):
- __requires__ = "range_types", "psycopg2_compatibility"
+ __requires__ = "range_types", "psycopg_compatibility"
__backend__ = True
def extras(self):
@@ -3364,8 +3364,18 @@ class _RangeTypeRoundTrip(fixtures.TablesTest):
# older psycopg2 versions.
if testing.against("postgresql+psycopg2cffi"):
from psycopg2cffi import extras
- else:
+ elif testing.against("postgresql+psycopg2"):
from psycopg2 import extras
+ elif testing.against("postgresql+psycopg"):
+ from psycopg.types.range import Range
+
+ class psycopg_extras:
+ def __getattr__(self, _):
+ return Range
+
+ extras = psycopg_extras()
+ else:
+ assert False, "Unknonw dialect"
return extras
@classmethod