| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-11-22 14:28:26 -0500 |
|---|---|---|
| committer | Mike Bayer <mike_mp@zzzcomputing.com> | 2021-11-23 16:52:55 -0500 |
| commit | 939de240d31a5441ad7380738d410a976d4ecc3a (patch) | |
| tree | e5261a905636fa473760b1e81894453112bbaa66 /lib/sqlalchemy/dialects/postgresql | |
| parent | d3a4e96196cd47858de072ae589c6554088edc24 (diff) | |
| download | sqlalchemy-939de240d31a5441ad7380738d410a976d4ecc3a.tar.gz | |
propose emulated setinputsizes embedded in the compiler
Add a new system so that PostgreSQL and other dialects have a
reliable way to add casts to bound parameters in SQL statements,
replacing previous use of setinputsizes() for PG dialects.
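
For illustration, the cast now travels inside the compiled SQL string itself. A hedged sketch, using invented table and column names, of what this looks like from the user's side:

```python
# Illustrative only: table/column names are invented; the exact SQL and the
# rendered cast depend on the dialect and datatype in use.
from sqlalchemy import Column, Integer, MetaData, Table, select

metadata = MetaData()
some_table = Table("some_table", metadata, Column("id", Integer, primary_key=True))

stmt = select(some_table).where(some_table.c.id == 5)

# With a dialect that renders bind casts (e.g. asyncpg here), the compiled
# statement carries the cast inline, roughly:
#
#   SELECT some_table.id FROM some_table WHERE some_table.id = $1::INTEGER
#
# rather than relying on a hidden setinputsizes() step applied after
# compilation, so the SQL seen in the engine log is the SQL that runs.
```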
Rationale:
1. psycopg3 will be using the same SQLAlchemy-side "setinputsizes"
as asyncpg, so we will be seeing a lot more of this
2. the full rendering performed by SQLAlchemy's compilation appears
in the engine log as well as in error messages. Without this change
there are three levels of SQL rendering: the compiler, the hidden
"setinputsizes" step inside SQLAlchemy, and then whatever the DBAPI
driver does. With the new approach, users reporting bugs are no
longer confronted with as many as two separate layers of "hidden
rendering"; SQLAlchemy's rendering is again fully transparent
3. calling a setinputsizes() method for every statement execution
is expensive; this way, the work is done behind the caching layer
4. for "fast insertmany()", I also want there to be a fast approach
towards setinputsizes. As it was, we were going to be taking
a SQL INSERT with thousands of bound parameter placeholders and
running a whole second pass on it to apply typecasts. this way,
we will at least be able to build the SQL string once without a huge
second pass over the whole string
5. psycopg2 can use this same system for its ARRAY casts
6. the general need for PostgreSQL to have lots of type casts is
now handled mostly in the base PostgreSQL dialect and works
independently of a DBAPI being present. Dependence on DBAPI symbols
that aren't complete / consistent / hashable is removed (a minimal
sketch of the new hooks follows this list)
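
A minimal sketch of the new hooks, paraphrased from the diff below rather than taken as the definitive implementation: a dialect-level type opts in with a class attribute, and a compiler-level hook appends the cast to the already-rendered parameter text.

```python
# Paraphrased sketch, not the actual implementation.
from sqlalchemy import types as sqltypes


class AsyncpgInteger(sqltypes.Integer):
    # bound parameters of this type get an inline "::INTEGER"-style cast
    # rendered at compile time, behind the statement caching layer
    render_bind_cast = True


def render_bind_cast(compiler, type_, dbapi_type, sqltext):
    # mirrors the shape of PGCompiler.render_bind_cast in the diff: the
    # dialect's type compiler supplies the SQL type name used in the cast
    return f"{sqltext}::{compiler.dialect.type_compiler.process(dbapi_type)}"
```

The real wiring goes through the new BindTyping.RENDER_CASTS dialect setting and the PGCompiler.render_bind_cast method shown in base.py below.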
I was originally going to try to build this into bind_expression(),
but that turned out to work poorly with custom bind_expression()
implementations as well as with empty sets. The current implementation
also doesn't need to run a second expression pass over the POSTCOMPILE
sections, which came out better than I originally thought it would.
Change-Id: I363e6d593d059add7bcc6d1f6c3f91dd2e683c0c
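
As an aside on the POSTCOMPILE point above, a hedged illustration of the kind of statement involved; the placeholder text shown is approximate and the names are invented:

```python
# Illustrative only: an expanding IN parameter is the typical POSTCOMPILE
# section; each expanded placeholder can receive its cast as the final SQL
# is written out, with no second pass over the finished string.
from sqlalchemy import Column, Integer, MetaData, Table, select

metadata = MetaData()
some_table = Table("some_table", metadata, Column("id", Integer))

stmt = select(some_table).where(some_table.c.id.in_([1, 2, 3]))

# Prior to parameter expansion the compiler holds a placeholder along the
# lines of:
#
#   WHERE some_table.id IN (__[POSTCOMPILE_id_1])
#
# which is expanded per-execution into individual (cast) parameters.
```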
Diffstat (limited to 'lib/sqlalchemy/dialects/postgresql')
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/array.py    |   3 |
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py  | 152 |
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/base.py     |  30 |
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/pg8000.py   |  95 |
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/psycopg2.py |   4 |
5 files changed, 99 insertions, 185 deletions
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index 0cb574dac..ff590c1b0 100644
--- a/lib/sqlalchemy/dialects/postgresql/array.py
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -330,9 +330,6 @@ class ARRAY(sqltypes.ARRAY):
             and self.item_type.native_enum
         )

-    def bind_expression(self, bindvalue):
-        return bindvalue
-
     def bind_processor(self, dialect):
         item_proc = self.item_type.dialect_impl(dialect).bind_processor(
             dialect
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index fe1f9fd5a..4ac0971e5 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -134,32 +134,28 @@ except ImportError:
     _python_UUID = None


+class AsyncpgString(sqltypes.String):
+    render_bind_cast = True
+
+
 class AsyncpgTime(sqltypes.Time):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.TIME
+    render_bind_cast = True


 class AsyncpgDate(sqltypes.Date):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.DATE
+    render_bind_cast = True


 class AsyncpgDateTime(sqltypes.DateTime):
-    def get_dbapi_type(self, dbapi):
-        if self.timezone:
-            return dbapi.TIMESTAMP_W_TZ
-        else:
-            return dbapi.TIMESTAMP
+    render_bind_cast = True


 class AsyncpgBoolean(sqltypes.Boolean):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.BOOLEAN
+    render_bind_cast = True


 class AsyncPgInterval(INTERVAL):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTERVAL
+    render_bind_cast = True

     @classmethod
     def adapt_emulated_to_native(cls, interval, **kw):
@@ -168,49 +164,45 @@ class AsyncPgInterval(INTERVAL):


 class AsyncPgEnum(ENUM):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.ENUM
+    render_bind_cast = True


 class AsyncpgInteger(sqltypes.Integer):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    render_bind_cast = True


 class AsyncpgBigInteger(sqltypes.BigInteger):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.BIGINTEGER
+    render_bind_cast = True


 class AsyncpgJSON(json.JSON):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.JSON
+    render_bind_cast = True

     def result_processor(self, dialect, coltype):
         return None


 class AsyncpgJSONB(json.JSONB):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.JSONB
+    render_bind_cast = True

     def result_processor(self, dialect, coltype):
         return None


 class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
-    def get_dbapi_type(self, dbapi):
-        raise NotImplementedError("should not be here")
+    pass


 class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    __visit_name__ = "json_int_index"
+
+    render_bind_cast = True


 class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.STRING
+    __visit_name__ = "json_str_index"
+
+    render_bind_cast = True


 class AsyncpgJSONPathType(json.JSONPathType):
@@ -224,8 +216,7 @@ class AsyncpgJSONPathType(json.JSONPathType):


 class AsyncpgUUID(UUID):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.UUID
+    render_bind_cast = True

     def bind_processor(self, dialect):
         if not self.as_uuid and dialect.use_native_uuid:
@@ -249,8 +240,7 @@ class AsyncpgUUID(UUID):


 class AsyncpgNumeric(sqltypes.Numeric):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.NUMBER
+    render_bind_cast = True

     def bind_processor(self, dialect):
         return None
@@ -281,18 +271,16 @@ class AsyncpgNumeric(sqltypes.Numeric):


 class AsyncpgFloat(AsyncpgNumeric):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.FLOAT
+    __visit_name__ = "float"
+    render_bind_cast = True


 class AsyncpgREGCLASS(REGCLASS):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.STRING
+    render_bind_cast = True


 class AsyncpgOID(OID):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    render_bind_cast = True


 class PGExecutionContext_asyncpg(PGExecutionContext):
@@ -317,11 +305,6 @@ class PGExecutionContext_asyncpg(PGExecutionContext):
         if not self.compiled:
             return

-        # we have to exclude ENUM because "enum" not really a "type"
-        # we can cast to, it has to be the name of the type itself.
-        # for now we just omit it from casting
-        self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}
-
     def create_server_side_cursor(self):
         return self._dbapi_connection.cursor(server_side=True)
@@ -367,15 +350,7 @@ class AsyncAdapt_asyncpg_cursor:
             self._adapt_connection._handle_exception(error)

     def _parameter_placeholders(self, params):
-        if not self._inputsizes:
-            return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
-        else:
-            return tuple(
-                "$%d::%s" % (idx, typ) if typ else "$%d" % idx
-                for idx, typ in enumerate(
-                    (_pg_types.get(typ) for typ in self._inputsizes), 1
-                )
-            )
+        return tuple(f"${idx:d}" for idx, _ in enumerate(params, 1))

     async def _prepare_and_execute(self, operation, parameters):
         adapt_connection = self._adapt_connection
@@ -464,7 +439,7 @@ class AsyncAdapt_asyncpg_cursor:
         )

     def setinputsizes(self, *inputsizes):
-        self._inputsizes = inputsizes
+        raise NotImplementedError()

     def __iter__(self):
         while self._rows:
@@ -798,6 +773,12 @@ class AsyncAdapt_asyncpg_dbapi:
             "all prepared caches in response to this exception)",
         )

+    # pep-249 datatype placeholders. As of SQLAlchemy 2.0 these aren't
+    # used, however the test suite looks for these in a few cases.
+    STRING = util.symbol("STRING")
+    NUMBER = util.symbol("NUMBER")
+    DATETIME = util.symbol("DATETIME")
+
     @util.memoized_property
     def _asyncpg_error_translate(self):
         import asyncpg
@@ -814,50 +795,6 @@ class AsyncAdapt_asyncpg_dbapi:
     def Binary(self, value):
         return value

-    STRING = util.symbol("STRING")
-    TIMESTAMP = util.symbol("TIMESTAMP")
-    TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
-    TIME = util.symbol("TIME")
-    DATE = util.symbol("DATE")
-    INTERVAL = util.symbol("INTERVAL")
-    NUMBER = util.symbol("NUMBER")
-    FLOAT = util.symbol("FLOAT")
-    BOOLEAN = util.symbol("BOOLEAN")
-    INTEGER = util.symbol("INTEGER")
-    BIGINTEGER = util.symbol("BIGINTEGER")
-    BYTES = util.symbol("BYTES")
-    DECIMAL = util.symbol("DECIMAL")
-    JSON = util.symbol("JSON")
-    JSONB = util.symbol("JSONB")
-    ENUM = util.symbol("ENUM")
-    UUID = util.symbol("UUID")
-    BYTEA = util.symbol("BYTEA")
-
-    DATETIME = TIMESTAMP
-    BINARY = BYTEA
-
-
-_pg_types = {
-    AsyncAdapt_asyncpg_dbapi.STRING: "varchar",
-    AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp",
-    AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone",
-    AsyncAdapt_asyncpg_dbapi.DATE: "date",
-    AsyncAdapt_asyncpg_dbapi.TIME: "time",
-    AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval",
-    AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric",
-    AsyncAdapt_asyncpg_dbapi.FLOAT: "float",
-    AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool",
-    AsyncAdapt_asyncpg_dbapi.INTEGER: "integer",
-    AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint",
-    AsyncAdapt_asyncpg_dbapi.BYTES: "bytes",
-    AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal",
-    AsyncAdapt_asyncpg_dbapi.JSON: "json",
-    AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb",
-    AsyncAdapt_asyncpg_dbapi.ENUM: "enum",
-    AsyncAdapt_asyncpg_dbapi.UUID: "uuid",
-    AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea",
-}
-

 class PGDialect_asyncpg(PGDialect):
     driver = "asyncpg"
@@ -865,19 +802,20 @@ class PGDialect_asyncpg(PGDialect):

     supports_server_side_cursors = True

+    render_bind_cast = True
+
     default_paramstyle = "format"
     supports_sane_multi_rowcount = False
     execution_ctx_cls = PGExecutionContext_asyncpg
     statement_compiler = PGCompiler_asyncpg
     preparer = PGIdentifierPreparer_asyncpg
-    use_setinputsizes = True
-
     use_native_uuid = True

     colspecs = util.update_copy(
         PGDialect.colspecs,
         {
+            sqltypes.String: AsyncpgString,
             sqltypes.Time: AsyncpgTime,
             sqltypes.Date: AsyncpgDate,
             sqltypes.DateTime: AsyncpgDateTime,
@@ -977,20 +915,6 @@ class PGDialect_asyncpg(PGDialect):
             e, self.dbapi.InterfaceError
         ) and "connection is closed" in str(e)

-    def do_set_input_sizes(self, cursor, list_of_tuples, context):
-        if self.positional:
-            cursor.setinputsizes(
-                *[dbtype for key, dbtype, sqltype in list_of_tuples]
-            )
-        else:
-            cursor.setinputsizes(
-                **{
-                    key: dbtype
-                    for key, dbtype, sqltype in list_of_tuples
-                    if dbtype
-                }
-            )
-
     async def setup_asyncpg_json_codec(self, conn):
         """set up JSON codec for asyncpg.
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 583d9c263..800b289fb 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1388,6 +1388,7 @@ from ... import sql
 from ... import util
 from ...engine import characteristics
 from ...engine import default
+from ...engine import interfaces
 from ...engine import reflection
 from ...sql import coercions
 from ...sql import compiler
@@ -2041,16 +2042,6 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
         self.drop(bind=bind, checkfirst=checkfirst)


-class _ColonCast(elements.CompilerColumnElement):
-    __visit_name__ = "colon_cast"
-    __slots__ = ("type", "clause", "typeclause")
-
-    def __init__(self, expression, type_):
-        self.type = type_
-        self.clause = expression
-        self.typeclause = elements.TypeClause(type_)
-
-
 colspecs = {
     sqltypes.ARRAY: _array.ARRAY,
     sqltypes.Interval: INTERVAL,
@@ -2106,11 +2097,12 @@ ischema_names = {


 class PGCompiler(compiler.SQLCompiler):
-    def visit_colon_cast(self, element, **kw):
-        return "%s::%s" % (
-            element.clause._compiler_dispatch(self, **kw),
-            element.typeclause._compiler_dispatch(self, **kw),
-        )
+    def render_bind_cast(self, type_, dbapi_type, sqltext):
+        return f"""{sqltext}::{
+            self.dialect.type_compiler.process(
+                dbapi_type, identifier_preparer=self.preparer
+            )
+        }"""

     def visit_array(self, element, **kw):
         return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
@@ -2854,6 +2846,12 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
     def visit_TSTZRANGE(self, type_, **kw):
         return "TSTZRANGE"

+    def visit_json_int_index(self, type_, **kw):
+        return "INT"
+
+    def visit_json_str_index(self, type_, **kw):
+        return "TEXT"
+
     def visit_datetime(self, type_, **kw):
         return self.visit_TIMESTAMP(type_, **kw)
@@ -3121,6 +3119,8 @@ class PGDialect(default.DefaultDialect):
     max_identifier_length = 63
     supports_sane_rowcount = True

+    bind_typing = interfaces.BindTyping.RENDER_CASTS
+
     supports_native_enum = True
     supports_native_boolean = True
     supports_smallserial = True
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index 324007e7e..e849d0499 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -94,7 +94,6 @@ import re
 from uuid import UUID as _python_UUID

 from .array import ARRAY as PGARRAY
-from .base import _ColonCast
 from .base import _DECIMAL_TYPES
 from .base import _FLOAT_TYPES
 from .base import _INT_TYPES
@@ -115,7 +114,13 @@ from ... import util
 from ...sql.elements import quoted_name


+class _PGString(sqltypes.String):
+    render_bind_cast = True
+
+
 class _PGNumeric(sqltypes.Numeric):
+    render_bind_cast = True
+
     def result_processor(self, dialect, coltype):
         if self.asdecimal:
             if coltype in _FLOAT_TYPES:
@@ -141,26 +146,29 @@ class _PGNumeric(sqltypes.Numeric):
         )


+class _PGFloat(_PGNumeric):
+    __visit_name__ = "float"
+    render_bind_cast = True
+
+
 class _PGNumericNoBind(_PGNumeric):
     def bind_processor(self, dialect):
         return None


 class _PGJSON(JSON):
+    render_bind_cast = True
+
     def result_processor(self, dialect, coltype):
         return None

-    def get_dbapi_type(self, dbapi):
-        return dbapi.JSON
-

 class _PGJSONB(JSONB):
+    render_bind_cast = True
+
     def result_processor(self, dialect, coltype):
         return None

-    def get_dbapi_type(self, dbapi):
-        return dbapi.JSONB
-

 class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
     def get_dbapi_type(self, dbapi):
@@ -168,21 +176,26 @@ class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):


 class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    __visit_name__ = "json_int_index"
+
+    render_bind_cast = True


 class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.STRING
+    __visit_name__ = "json_str_index"
+
+    render_bind_cast = True


 class _PGJSONPathType(JSONPathType):
-    def get_dbapi_type(self, dbapi):
-        return 1009
+    pass
+
+    # DBAPI type 1009


 class _PGUUID(UUID):
+    render_bind_cast = True
+
     def bind_processor(self, dialect):
         if not self.as_uuid:
@@ -210,6 +223,8 @@ class _PGEnum(ENUM):


 class _PGInterval(INTERVAL):
+    render_bind_cast = True
+
     def get_dbapi_type(self, dbapi):
         return dbapi.INTERVAL

@@ -219,48 +234,39 @@ class _PGInterval(INTERVAL):


 class _PGTimeStamp(sqltypes.DateTime):
-    def get_dbapi_type(self, dbapi):
-        if self.timezone:
-            # TIMESTAMPTZOID
-            return 1184
-        else:
-            # TIMESTAMPOID
-            return 1114
+    render_bind_cast = True
+
+
+class _PGDate(sqltypes.Date):
+    render_bind_cast = True


 class _PGTime(sqltypes.Time):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.TIME
+    render_bind_cast = True


 class _PGInteger(sqltypes.Integer):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    render_bind_cast = True


 class _PGSmallInteger(sqltypes.SmallInteger):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.INTEGER
+    render_bind_cast = True


 class _PGNullType(sqltypes.NullType):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.NULLTYPE
+    pass


 class _PGBigInteger(sqltypes.BigInteger):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.BIGINTEGER
+    render_bind_cast = True


 class _PGBoolean(sqltypes.Boolean):
-    def get_dbapi_type(self, dbapi):
-        return dbapi.BOOLEAN
+    render_bind_cast = True


 class _PGARRAY(PGARRAY):
-    def bind_expression(self, bindvalue):
-        return _ColonCast(bindvalue, self)
+    render_bind_cast = True


 _server_side_id = util.counter()
@@ -362,7 +368,7 @@ class PGDialect_pg8000(PGDialect):
     preparer = PGIdentifierPreparer_pg8000
     supports_server_side_cursors = True

-    use_setinputsizes = True
+    render_bind_cast = True

     # reversed as of pg8000 1.16.6.  1.16.5 and lower
     # are no longer compatible
@@ -372,8 +378,9 @@ class PGDialect_pg8000(PGDialect):
     colspecs = util.update_copy(
         PGDialect.colspecs,
         {
+            sqltypes.String: _PGString,
             sqltypes.Numeric: _PGNumericNoBind,
-            sqltypes.Float: _PGNumeric,
+            sqltypes.Float: _PGFloat,
             sqltypes.JSON: _PGJSON,
             sqltypes.Boolean: _PGBoolean,
             sqltypes.NullType: _PGNullType,
@@ -386,6 +393,8 @@ class PGDialect_pg8000(PGDialect):
             sqltypes.Interval: _PGInterval,
             INTERVAL: _PGInterval,
             sqltypes.DateTime: _PGTimeStamp,
+            sqltypes.DateTime: _PGTimeStamp,
+            sqltypes.Date: _PGDate,
             sqltypes.Time: _PGTime,
             sqltypes.Integer: _PGInteger,
             sqltypes.SmallInteger: _PGSmallInteger,
@@ -517,20 +526,6 @@ class PGDialect_pg8000(PGDialect):
             cursor.execute("COMMIT")
             cursor.close()

-    def do_set_input_sizes(self, cursor, list_of_tuples, context):
-        if self.positional:
-            cursor.setinputsizes(
-                *[dbtype for key, dbtype, sqltype in list_of_tuples]
-            )
-        else:
-            cursor.setinputsizes(
-                **{
-                    key: dbtype
-                    for key, dbtype, sqltype in list_of_tuples
-                    if dbtype
-                }
-            )
-
     def do_begin_twophase(self, connection, xid):
         connection.connection.tpc_begin((0, xid, ""))
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index f62830a0d..19c01d208 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -449,7 +449,6 @@ import re
 from uuid import UUID as _python_UUID

 from .array import ARRAY as PGARRAY
-from .base import _ColonCast
 from .base import _DECIMAL_TYPES
 from .base import _FLOAT_TYPES
 from .base import _INT_TYPES
@@ -516,8 +515,7 @@ class _PGHStore(HSTORE):


 class _PGARRAY(PGARRAY):
-    def bind_expression(self, bindvalue):
-        return _ColonCast(bindvalue, self)
+    render_bind_cast = True


 class _PGJSON(JSON):
