summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2021-11-22 14:28:26 -0500
committerMike Bayer <mike_mp@zzzcomputing.com>2021-11-23 16:52:55 -0500
commit939de240d31a5441ad7380738d410a976d4ecc3a (patch)
treee5261a905636fa473760b1e81894453112bbaa66 /lib/sqlalchemy
parentd3a4e96196cd47858de072ae589c6554088edc24 (diff)
downloadsqlalchemy-939de240d31a5441ad7380738d410a976d4ecc3a.tar.gz
propose emulated setinputsizes embedded in the compiler
Add a new system so that PostgreSQL and other dialects have a reliable way to add casts to bound parameters in SQL statements, replacing previous use of setinputsizes() for PG dialects. rationale: 1. psycopg3 will be using the same SQLAlchemy-side "setinputsizes" as asyncpg, so we will be seeing a lot more of this 2. the full rendering that SQLAlchemy's compilation is performing is in the engine log as well as error messages. Without this, we introduce three levels of SQL rendering, the compiler, the hidden "setinputsizes" in SQLAlchemy, and then whatever the DBAPI driver does. With this new approach, users reporting bugs etc. will be less confused that there are as many as two separate layers of "hidden rendering"; SQLAlchemy's rendering is again fully transparent 3. calling upon a setinputsizes() method for every statement execution is expensive. this way, the work is done behind the caching layer 4. for "fast insertmany()", I also want there to be a fast approach towards setinputsizes. As it was, we were going to be taking a SQL INSERT with thousands of bound parameter placeholders and running a whole second pass on it to apply typecasts. this way, we will at least be able to build the SQL string once without a huge second pass over the whole string 5. psycopg2 can use this same system for its ARRAY casts 6. the general need for PostgreSQL to have lots of type casts is now mostly in the base PostgreSQL dialect and works independently of a DBAPI being present. dependence on DBAPI symbols that aren't complete / consistent / hashable is removed I was originally going to try to build this into bind_expression(), but it was revealed this worked poorly with custom bind_expression() as well as empty sets. the current impl also doesn't need to run a second expression pass over the POSTCOMPILE sections, which came out better than I originally thought it would. Change-Id: I363e6d593d059add7bcc6d1f6c3f91dd2e683c0c
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/connectors/pyodbc.py11
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py14
-rw-r--r--lib/sqlalchemy/dialects/postgresql/array.py3
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py152
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py30
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pg8000.py95
-rw-r--r--lib/sqlalchemy/dialects/postgresql/psycopg2.py4
-rw-r--r--lib/sqlalchemy/engine/__init__.py1
-rw-r--r--lib/sqlalchemy/engine/base.py3
-rw-r--r--lib/sqlalchemy/engine/default.py37
-rw-r--r--lib/sqlalchemy/engine/interfaces.py70
-rw-r--r--lib/sqlalchemy/sql/compiler.py115
-rw-r--r--lib/sqlalchemy/sql/type_api.py15
-rw-r--r--lib/sqlalchemy/testing/assertsql.py8
-rw-r--r--lib/sqlalchemy/testing/suite/test_types.py6
15 files changed, 296 insertions, 268 deletions
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index 411985b5d..fdaa8981a 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -9,6 +9,7 @@ import re
from . import Connector
from .. import util
+from ..engine import interfaces
class PyODBCConnector(Connector):
@@ -21,15 +22,14 @@ class PyODBCConnector(Connector):
supports_native_decimal = True
default_paramstyle = "named"
- use_setinputsizes = False
-
# for non-DSN connections, this *may* be used to
# hold the desired driver name
pyodbc_driver_name = None
def __init__(self, use_setinputsizes=False, **kw):
super(PyODBCConnector, self).__init__(**kw)
- self.use_setinputsizes = use_setinputsizes
+ if use_setinputsizes:
+ self.bind_typing = interfaces.BindTyping.SETINPUTSIZES
@classmethod
def dbapi(cls):
@@ -160,8 +160,9 @@ class PyODBCConnector(Connector):
# for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype
# that ticket #5649 is targeting.
- # NOTE: as of #6058, this won't be called if the use_setinputsizes flag
- # is False, or if no types were specified in list_of_tuples
+ # NOTE: as of #6058, this won't be called if the use_setinputsizes
+ # parameter were not passed to the dialect, or if no types were
+ # specified in list_of_tuples
cursor.setinputsizes(
[
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index 2cfcb0e5c..672cbd7d9 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -446,6 +446,7 @@ from ... import processors
from ... import types as sqltypes
from ... import util
from ...engine import cursor as _cursor
+from ...engine import interfaces
class _OracleInteger(sqltypes.Integer):
@@ -783,8 +784,6 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
self._generate_cursor_outputtype_handler()
- self.include_set_input_sizes = self.dialect._include_setinputsizes
-
def post_exec(self):
if self.compiled and self.out_parameters and self.compiled.returning:
# create a fake cursor result from the out parameters. unlike
@@ -833,7 +832,7 @@ class OracleDialect_cx_oracle(OracleDialect):
supports_sane_rowcount = True
supports_sane_multi_rowcount = True
- use_setinputsizes = True
+ bind_typing = interfaces.BindTyping.SETINPUTSIZES
driver = "cx_oracle"
@@ -909,7 +908,6 @@ class OracleDialect_cx_oracle(OracleDialect):
cx_Oracle = self.dbapi
if cx_Oracle is None:
- self._include_setinputsizes = {}
self.cx_oracle_ver = (0, 0, 0)
else:
self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version)
@@ -925,7 +923,7 @@ class OracleDialect_cx_oracle(OracleDialect):
)
self._cursor_var_unicode_kwargs = util.immutabledict()
- self._include_setinputsizes = {
+ self.include_set_input_sizes = {
cx_Oracle.DATETIME,
cx_Oracle.NCLOB,
cx_Oracle.CLOB,
@@ -935,9 +933,9 @@ class OracleDialect_cx_oracle(OracleDialect):
cx_Oracle.BLOB,
cx_Oracle.FIXED_CHAR,
cx_Oracle.TIMESTAMP,
- _OracleInteger,
- _OracleBINARY_FLOAT,
- _OracleBINARY_DOUBLE,
+ int, # _OracleInteger,
+ # _OracleBINARY_FLOAT, _OracleBINARY_DOUBLE,
+ cx_Oracle.NATIVE_FLOAT,
}
self._paramval = lambda value: value.getvalue()
diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py
index 0cb574dac..ff590c1b0 100644
--- a/lib/sqlalchemy/dialects/postgresql/array.py
+++ b/lib/sqlalchemy/dialects/postgresql/array.py
@@ -330,9 +330,6 @@ class ARRAY(sqltypes.ARRAY):
and self.item_type.native_enum
)
- def bind_expression(self, bindvalue):
- return bindvalue
-
def bind_processor(self, dialect):
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
dialect
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index fe1f9fd5a..4ac0971e5 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -134,32 +134,28 @@ except ImportError:
_python_UUID = None
+class AsyncpgString(sqltypes.String):
+ render_bind_cast = True
+
+
class AsyncpgTime(sqltypes.Time):
- def get_dbapi_type(self, dbapi):
- return dbapi.TIME
+ render_bind_cast = True
class AsyncpgDate(sqltypes.Date):
- def get_dbapi_type(self, dbapi):
- return dbapi.DATE
+ render_bind_cast = True
class AsyncpgDateTime(sqltypes.DateTime):
- def get_dbapi_type(self, dbapi):
- if self.timezone:
- return dbapi.TIMESTAMP_W_TZ
- else:
- return dbapi.TIMESTAMP
+ render_bind_cast = True
class AsyncpgBoolean(sqltypes.Boolean):
- def get_dbapi_type(self, dbapi):
- return dbapi.BOOLEAN
+ render_bind_cast = True
class AsyncPgInterval(INTERVAL):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTERVAL
+ render_bind_cast = True
@classmethod
def adapt_emulated_to_native(cls, interval, **kw):
@@ -168,49 +164,45 @@ class AsyncPgInterval(INTERVAL):
class AsyncPgEnum(ENUM):
- def get_dbapi_type(self, dbapi):
- return dbapi.ENUM
+ render_bind_cast = True
class AsyncpgInteger(sqltypes.Integer):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTEGER
+ render_bind_cast = True
class AsyncpgBigInteger(sqltypes.BigInteger):
- def get_dbapi_type(self, dbapi):
- return dbapi.BIGINTEGER
+ render_bind_cast = True
class AsyncpgJSON(json.JSON):
- def get_dbapi_type(self, dbapi):
- return dbapi.JSON
+ render_bind_cast = True
def result_processor(self, dialect, coltype):
return None
class AsyncpgJSONB(json.JSONB):
- def get_dbapi_type(self, dbapi):
- return dbapi.JSONB
+ render_bind_cast = True
def result_processor(self, dialect, coltype):
return None
class AsyncpgJSONIndexType(sqltypes.JSON.JSONIndexType):
- def get_dbapi_type(self, dbapi):
- raise NotImplementedError("should not be here")
+ pass
class AsyncpgJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTEGER
+ __visit_name__ = "json_int_index"
+
+ render_bind_cast = True
class AsyncpgJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
- def get_dbapi_type(self, dbapi):
- return dbapi.STRING
+ __visit_name__ = "json_str_index"
+
+ render_bind_cast = True
class AsyncpgJSONPathType(json.JSONPathType):
@@ -224,8 +216,7 @@ class AsyncpgJSONPathType(json.JSONPathType):
class AsyncpgUUID(UUID):
- def get_dbapi_type(self, dbapi):
- return dbapi.UUID
+ render_bind_cast = True
def bind_processor(self, dialect):
if not self.as_uuid and dialect.use_native_uuid:
@@ -249,8 +240,7 @@ class AsyncpgUUID(UUID):
class AsyncpgNumeric(sqltypes.Numeric):
- def get_dbapi_type(self, dbapi):
- return dbapi.NUMBER
+ render_bind_cast = True
def bind_processor(self, dialect):
return None
@@ -281,18 +271,16 @@ class AsyncpgNumeric(sqltypes.Numeric):
class AsyncpgFloat(AsyncpgNumeric):
- def get_dbapi_type(self, dbapi):
- return dbapi.FLOAT
+ __visit_name__ = "float"
+ render_bind_cast = True
class AsyncpgREGCLASS(REGCLASS):
- def get_dbapi_type(self, dbapi):
- return dbapi.STRING
+ render_bind_cast = True
class AsyncpgOID(OID):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTEGER
+ render_bind_cast = True
class PGExecutionContext_asyncpg(PGExecutionContext):
@@ -317,11 +305,6 @@ class PGExecutionContext_asyncpg(PGExecutionContext):
if not self.compiled:
return
- # we have to exclude ENUM because "enum" not really a "type"
- # we can cast to, it has to be the name of the type itself.
- # for now we just omit it from casting
- self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}
-
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(server_side=True)
@@ -367,15 +350,7 @@ class AsyncAdapt_asyncpg_cursor:
self._adapt_connection._handle_exception(error)
def _parameter_placeholders(self, params):
- if not self._inputsizes:
- return tuple("$%d" % idx for idx, _ in enumerate(params, 1))
- else:
- return tuple(
- "$%d::%s" % (idx, typ) if typ else "$%d" % idx
- for idx, typ in enumerate(
- (_pg_types.get(typ) for typ in self._inputsizes), 1
- )
- )
+ return tuple(f"${idx:d}" for idx, _ in enumerate(params, 1))
async def _prepare_and_execute(self, operation, parameters):
adapt_connection = self._adapt_connection
@@ -464,7 +439,7 @@ class AsyncAdapt_asyncpg_cursor:
)
def setinputsizes(self, *inputsizes):
- self._inputsizes = inputsizes
+ raise NotImplementedError()
def __iter__(self):
while self._rows:
@@ -798,6 +773,12 @@ class AsyncAdapt_asyncpg_dbapi:
"all prepared caches in response to this exception)",
)
+ # pep-249 datatype placeholders. As of SQLAlchemy 2.0 these aren't
+ # used, however the test suite looks for these in a few cases.
+ STRING = util.symbol("STRING")
+ NUMBER = util.symbol("NUMBER")
+ DATETIME = util.symbol("DATETIME")
+
@util.memoized_property
def _asyncpg_error_translate(self):
import asyncpg
@@ -814,50 +795,6 @@ class AsyncAdapt_asyncpg_dbapi:
def Binary(self, value):
return value
- STRING = util.symbol("STRING")
- TIMESTAMP = util.symbol("TIMESTAMP")
- TIMESTAMP_W_TZ = util.symbol("TIMESTAMP_W_TZ")
- TIME = util.symbol("TIME")
- DATE = util.symbol("DATE")
- INTERVAL = util.symbol("INTERVAL")
- NUMBER = util.symbol("NUMBER")
- FLOAT = util.symbol("FLOAT")
- BOOLEAN = util.symbol("BOOLEAN")
- INTEGER = util.symbol("INTEGER")
- BIGINTEGER = util.symbol("BIGINTEGER")
- BYTES = util.symbol("BYTES")
- DECIMAL = util.symbol("DECIMAL")
- JSON = util.symbol("JSON")
- JSONB = util.symbol("JSONB")
- ENUM = util.symbol("ENUM")
- UUID = util.symbol("UUID")
- BYTEA = util.symbol("BYTEA")
-
- DATETIME = TIMESTAMP
- BINARY = BYTEA
-
-
-_pg_types = {
- AsyncAdapt_asyncpg_dbapi.STRING: "varchar",
- AsyncAdapt_asyncpg_dbapi.TIMESTAMP: "timestamp",
- AsyncAdapt_asyncpg_dbapi.TIMESTAMP_W_TZ: "timestamp with time zone",
- AsyncAdapt_asyncpg_dbapi.DATE: "date",
- AsyncAdapt_asyncpg_dbapi.TIME: "time",
- AsyncAdapt_asyncpg_dbapi.INTERVAL: "interval",
- AsyncAdapt_asyncpg_dbapi.NUMBER: "numeric",
- AsyncAdapt_asyncpg_dbapi.FLOAT: "float",
- AsyncAdapt_asyncpg_dbapi.BOOLEAN: "bool",
- AsyncAdapt_asyncpg_dbapi.INTEGER: "integer",
- AsyncAdapt_asyncpg_dbapi.BIGINTEGER: "bigint",
- AsyncAdapt_asyncpg_dbapi.BYTES: "bytes",
- AsyncAdapt_asyncpg_dbapi.DECIMAL: "decimal",
- AsyncAdapt_asyncpg_dbapi.JSON: "json",
- AsyncAdapt_asyncpg_dbapi.JSONB: "jsonb",
- AsyncAdapt_asyncpg_dbapi.ENUM: "enum",
- AsyncAdapt_asyncpg_dbapi.UUID: "uuid",
- AsyncAdapt_asyncpg_dbapi.BYTEA: "bytea",
-}
-
class PGDialect_asyncpg(PGDialect):
driver = "asyncpg"
@@ -865,19 +802,20 @@ class PGDialect_asyncpg(PGDialect):
supports_server_side_cursors = True
+ render_bind_cast = True
+
default_paramstyle = "format"
supports_sane_multi_rowcount = False
execution_ctx_cls = PGExecutionContext_asyncpg
statement_compiler = PGCompiler_asyncpg
preparer = PGIdentifierPreparer_asyncpg
- use_setinputsizes = True
-
use_native_uuid = True
colspecs = util.update_copy(
PGDialect.colspecs,
{
+ sqltypes.String: AsyncpgString,
sqltypes.Time: AsyncpgTime,
sqltypes.Date: AsyncpgDate,
sqltypes.DateTime: AsyncpgDateTime,
@@ -977,20 +915,6 @@ class PGDialect_asyncpg(PGDialect):
e, self.dbapi.InterfaceError
) and "connection is closed" in str(e)
- def do_set_input_sizes(self, cursor, list_of_tuples, context):
- if self.positional:
- cursor.setinputsizes(
- *[dbtype for key, dbtype, sqltype in list_of_tuples]
- )
- else:
- cursor.setinputsizes(
- **{
- key: dbtype
- for key, dbtype, sqltype in list_of_tuples
- if dbtype
- }
- )
-
async def setup_asyncpg_json_codec(self, conn):
"""set up JSON codec for asyncpg.
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 583d9c263..800b289fb 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -1388,6 +1388,7 @@ from ... import sql
from ... import util
from ...engine import characteristics
from ...engine import default
+from ...engine import interfaces
from ...engine import reflection
from ...sql import coercions
from ...sql import compiler
@@ -2041,16 +2042,6 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
self.drop(bind=bind, checkfirst=checkfirst)
-class _ColonCast(elements.CompilerColumnElement):
- __visit_name__ = "colon_cast"
- __slots__ = ("type", "clause", "typeclause")
-
- def __init__(self, expression, type_):
- self.type = type_
- self.clause = expression
- self.typeclause = elements.TypeClause(type_)
-
-
colspecs = {
sqltypes.ARRAY: _array.ARRAY,
sqltypes.Interval: INTERVAL,
@@ -2106,11 +2097,12 @@ ischema_names = {
class PGCompiler(compiler.SQLCompiler):
- def visit_colon_cast(self, element, **kw):
- return "%s::%s" % (
- element.clause._compiler_dispatch(self, **kw),
- element.typeclause._compiler_dispatch(self, **kw),
- )
+ def render_bind_cast(self, type_, dbapi_type, sqltext):
+ return f"""{sqltext}::{
+ self.dialect.type_compiler.process(
+ dbapi_type, identifier_preparer=self.preparer
+ )
+ }"""
def visit_array(self, element, **kw):
return "ARRAY[%s]" % self.visit_clauselist(element, **kw)
@@ -2854,6 +2846,12 @@ class PGTypeCompiler(compiler.GenericTypeCompiler):
def visit_TSTZRANGE(self, type_, **kw):
return "TSTZRANGE"
+ def visit_json_int_index(self, type_, **kw):
+ return "INT"
+
+ def visit_json_str_index(self, type_, **kw):
+ return "TEXT"
+
def visit_datetime(self, type_, **kw):
return self.visit_TIMESTAMP(type_, **kw)
@@ -3121,6 +3119,8 @@ class PGDialect(default.DefaultDialect):
max_identifier_length = 63
supports_sane_rowcount = True
+ bind_typing = interfaces.BindTyping.RENDER_CASTS
+
supports_native_enum = True
supports_native_boolean = True
supports_smallserial = True
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index 324007e7e..e849d0499 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -94,7 +94,6 @@ import re
from uuid import UUID as _python_UUID
from .array import ARRAY as PGARRAY
-from .base import _ColonCast
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
@@ -115,7 +114,13 @@ from ... import util
from ...sql.elements import quoted_name
+class _PGString(sqltypes.String):
+ render_bind_cast = True
+
+
class _PGNumeric(sqltypes.Numeric):
+ render_bind_cast = True
+
def result_processor(self, dialect, coltype):
if self.asdecimal:
if coltype in _FLOAT_TYPES:
@@ -141,26 +146,29 @@ class _PGNumeric(sqltypes.Numeric):
)
+class _PGFloat(_PGNumeric):
+ __visit_name__ = "float"
+ render_bind_cast = True
+
+
class _PGNumericNoBind(_PGNumeric):
def bind_processor(self, dialect):
return None
class _PGJSON(JSON):
+ render_bind_cast = True
+
def result_processor(self, dialect, coltype):
return None
- def get_dbapi_type(self, dbapi):
- return dbapi.JSON
-
class _PGJSONB(JSONB):
+ render_bind_cast = True
+
def result_processor(self, dialect, coltype):
return None
- def get_dbapi_type(self, dbapi):
- return dbapi.JSONB
-
class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
def get_dbapi_type(self, dbapi):
@@ -168,21 +176,26 @@ class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTEGER
+ __visit_name__ = "json_int_index"
+
+ render_bind_cast = True
class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
- def get_dbapi_type(self, dbapi):
- return dbapi.STRING
+ __visit_name__ = "json_str_index"
+
+ render_bind_cast = True
class _PGJSONPathType(JSONPathType):
- def get_dbapi_type(self, dbapi):
- return 1009
+ pass
+
+ # DBAPI type 1009
class _PGUUID(UUID):
+ render_bind_cast = True
+
def bind_processor(self, dialect):
if not self.as_uuid:
@@ -210,6 +223,8 @@ class _PGEnum(ENUM):
class _PGInterval(INTERVAL):
+ render_bind_cast = True
+
def get_dbapi_type(self, dbapi):
return dbapi.INTERVAL
@@ -219,48 +234,39 @@ class _PGInterval(INTERVAL):
class _PGTimeStamp(sqltypes.DateTime):
- def get_dbapi_type(self, dbapi):
- if self.timezone:
- # TIMESTAMPTZOID
- return 1184
- else:
- # TIMESTAMPOID
- return 1114
+ render_bind_cast = True
+
+
+class _PGDate(sqltypes.Date):
+ render_bind_cast = True
class _PGTime(sqltypes.Time):
- def get_dbapi_type(self, dbapi):
- return dbapi.TIME
+ render_bind_cast = True
class _PGInteger(sqltypes.Integer):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTEGER
+ render_bind_cast = True
class _PGSmallInteger(sqltypes.SmallInteger):
- def get_dbapi_type(self, dbapi):
- return dbapi.INTEGER
+ render_bind_cast = True
class _PGNullType(sqltypes.NullType):
- def get_dbapi_type(self, dbapi):
- return dbapi.NULLTYPE
+ pass
class _PGBigInteger(sqltypes.BigInteger):
- def get_dbapi_type(self, dbapi):
- return dbapi.BIGINTEGER
+ render_bind_cast = True
class _PGBoolean(sqltypes.Boolean):
- def get_dbapi_type(self, dbapi):
- return dbapi.BOOLEAN
+ render_bind_cast = True
class _PGARRAY(PGARRAY):
- def bind_expression(self, bindvalue):
- return _ColonCast(bindvalue, self)
+ render_bind_cast = True
_server_side_id = util.counter()
@@ -362,7 +368,7 @@ class PGDialect_pg8000(PGDialect):
preparer = PGIdentifierPreparer_pg8000
supports_server_side_cursors = True
- use_setinputsizes = True
+ render_bind_cast = True
# reversed as of pg8000 1.16.6. 1.16.5 and lower
# are no longer compatible
@@ -372,8 +378,9 @@ class PGDialect_pg8000(PGDialect):
colspecs = util.update_copy(
PGDialect.colspecs,
{
+ sqltypes.String: _PGString,
sqltypes.Numeric: _PGNumericNoBind,
- sqltypes.Float: _PGNumeric,
+ sqltypes.Float: _PGFloat,
sqltypes.JSON: _PGJSON,
sqltypes.Boolean: _PGBoolean,
sqltypes.NullType: _PGNullType,
@@ -386,6 +393,8 @@ class PGDialect_pg8000(PGDialect):
sqltypes.Interval: _PGInterval,
INTERVAL: _PGInterval,
sqltypes.DateTime: _PGTimeStamp,
+ sqltypes.DateTime: _PGTimeStamp,
+ sqltypes.Date: _PGDate,
sqltypes.Time: _PGTime,
sqltypes.Integer: _PGInteger,
sqltypes.SmallInteger: _PGSmallInteger,
@@ -517,20 +526,6 @@ class PGDialect_pg8000(PGDialect):
cursor.execute("COMMIT")
cursor.close()
- def do_set_input_sizes(self, cursor, list_of_tuples, context):
- if self.positional:
- cursor.setinputsizes(
- *[dbtype for key, dbtype, sqltype in list_of_tuples]
- )
- else:
- cursor.setinputsizes(
- **{
- key: dbtype
- for key, dbtype, sqltype in list_of_tuples
- if dbtype
- }
- )
-
def do_begin_twophase(self, connection, xid):
connection.connection.tpc_begin((0, xid, ""))
diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
index f62830a0d..19c01d208 100644
--- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py
+++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py
@@ -449,7 +449,6 @@ import re
from uuid import UUID as _python_UUID
from .array import ARRAY as PGARRAY
-from .base import _ColonCast
from .base import _DECIMAL_TYPES
from .base import _FLOAT_TYPES
from .base import _INT_TYPES
@@ -516,8 +515,7 @@ class _PGHStore(HSTORE):
class _PGARRAY(PGARRAY):
- def bind_expression(self, bindvalue):
- return _ColonCast(bindvalue, self)
+ render_bind_cast = True
class _PGJSON(JSON):
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index ba57eee51..5f4c5be47 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -33,6 +33,7 @@ from .cursor import CursorResult
from .cursor import FullyBufferedResultProxy
from .cursor import ResultProxy
from .interfaces import AdaptedConnection
+from .interfaces import BindTyping
from .interfaces import Compiled
from .interfaces import Connectable
from .interfaces import CreateEnginePlugin
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 389270e45..61ef29d4a 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -9,6 +9,7 @@ from __future__ import with_statement
import contextlib
import sys
+from .interfaces import BindTyping
from .interfaces import Connectable
from .interfaces import ConnectionEventsTarget
from .interfaces import ExceptionContext
@@ -1486,7 +1487,7 @@ class Connection(Connectable):
context.pre_exec()
- if dialect.use_setinputsizes:
+ if dialect.bind_typing is BindTyping.SETINPUTSIZES:
context._set_input_sizes()
cursor, statement, parameters = (
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index 3af24d913..d36ed6e65 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -52,9 +52,13 @@ class DefaultDialect(interfaces.Dialect):
supports_alter = True
supports_comments = False
inline_comments = False
- use_setinputsizes = False
supports_statement_cache = True
+ bind_typing = interfaces.BindTyping.NONE
+
+ include_set_input_sizes = None
+ exclude_set_input_sizes = None
+
# the first value we'd get for an autoincrement
# column.
default_sequence_base = 1
@@ -260,6 +264,15 @@ class DefaultDialect(interfaces.Dialect):
else:
self.server_side_cursors = True
+ if getattr(self, "use_setinputsizes", False):
+ util.warn_deprecated(
+ "The dialect-level use_setinputsizes attribute is "
+ "deprecated. Please use "
+ "bind_typing = BindTyping.SETINPUTSIZES",
+ "2.0",
+ )
+ self.bind_typing = interfaces.BindTyping.SETINPUTSIZES
+
self.encoding = encoding
self.positional = False
self._ischema = None
@@ -287,6 +300,10 @@ class DefaultDialect(interfaces.Dialect):
self.label_length = label_length
self.compiler_linting = compiler_linting
+ @util.memoized_property
+ def _bind_typing_render_casts(self):
+ return self.bind_typing is interfaces.BindTyping.RENDER_CASTS
+
def _ensure_has_table_connection(self, arg):
if not isinstance(arg, Connection):
@@ -736,9 +753,6 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
returned_default_rows = None
execution_options = util.immutabledict()
- include_set_input_sizes = None
- exclude_set_input_sizes = None
-
cursor_fetch_strategy = _cursor._DEFAULT_FETCH
cache_stats = None
@@ -1373,8 +1387,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
- This method only called by those dialects which require it,
- currently cx_oracle, asyncpg and pg8000.
+ This method only called by those dialects which set
+ the :attr:`.Dialect.bind_typing` attribute to
+ :attr:`.BindTyping.SETINPUTSIZES`. cx_Oracle is the only DBAPI
+ that requires setinputsizes(), pyodbc offers it as an option.
+
+ Prior to SQLAlchemy 2.0, the setinputsizes() approach was also used
+ for pg8000 and asyncpg, which has been changed to inline rendering
+ of casts.
"""
if self.isddl or self.is_text:
@@ -1382,10 +1402,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
compiled = self.compiled
- inputsizes = compiled._get_set_input_sizes_lookup(
- include_types=self.include_set_input_sizes,
- exclude_types=self.exclude_set_input_sizes,
- )
+ inputsizes = compiled._get_set_input_sizes_lookup()
if inputsizes is None:
return
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index 6772a27bd..251d01c5e 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -7,10 +7,60 @@
"""Define core interfaces used by the engine system."""
+from enum import Enum
+
from ..sql.compiler import Compiled # noqa
from ..sql.compiler import TypeCompiler # noqa
+class BindTyping(Enum):
+ """Define different methods of passing typing information for
+ bound parameters in a statement to the database driver.
+
+ .. versionadded:: 2.0
+
+ """
+
+ NONE = 1
+ """No steps are taken to pass typing information to the database driver.
+
+ This is the default behavior for databases such as SQLite, MySQL / MariaDB,
+ SQL Server.
+
+ """
+
+ SETINPUTSIZES = 2
+ """Use the pep-249 setinputsizes method.
+
+ This is only implemented for DBAPIs that support this method and for which
+ the SQLAlchemy dialect has the appropriate infrastructure for that
+ dialect set up. Current dialects include cx_Oracle as well as
+ optional support for SQL Server using pyodbc.
+
+ When using setinputsizes, dialects also have a means of only using the
+ method for certain datatypes using include/exclude lists.
+
+ When SETINPUTSIZES is used, the :meth:`.Dialect.do_set_input_sizes` method
+ is called for each statement executed which has bound parameters.
+
+ """
+
+ RENDER_CASTS = 3
+ """Render casts or other directives in the SQL string.
+
+ This method is used for all PostgreSQL dialects, including asyncpg,
+ pg8000, psycopg, psycopg2. Dialects which implement this can choose
+ which kinds of datatypes are explicitly cast in SQL statements and which
+ aren't.
+
+ When RENDER_CASTS is used, the compiler will invoke the
+ :meth:`.SQLCompiler.render_bind_cast` method for each
+ :class:`.BindParameter` object whose dialect-level type sets the
+ :attr:`.TypeEngine.render_bind_cast` attribute.
+
+ """
+
+
class Dialect:
"""Define the behavior of a specific database and DB-API combination.
@@ -156,6 +206,16 @@ class Dialect:
"""
+ bind_typing = BindTyping.NONE
+ """define a means of passing typing information to the database and/or
+ driver for bound parameters.
+
+ See :class:`.BindTyping` for values.
+
+ ..versionadded:: 2.0
+
+ """
+
def create_connect_args(self, url):
"""Build DB-API compatible connection arguments.
@@ -587,7 +647,9 @@ class Dialect:
def do_set_input_sizes(self, cursor, list_of_tuples, context):
"""invoke the cursor.setinputsizes() method with appropriate arguments
- This hook is called if the dialect.use_inputsizes flag is set to True.
+ This hook is called if the :attr:`.Dialect.bind_typing` attribute is
+ set to the
+ :attr:`.BindTyping.SETINPUTSIZES` value.
Parameter data is passed in a list of tuples (paramname, dbtype,
sqltype), where ``paramname`` is the key of the parameter in the
statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the
@@ -595,6 +657,12 @@ class Dialect:
.. versionadded:: 1.4
+ .. versionchanged:: 2.0 - setinputsizes mode is now enabled by
+ setting :attr:`.Dialect.bind_typing` to
+ :attr:`.BindTyping.SETINPUTSIZES`. Dialects which accept
+ a ``use_setinputsizes`` parameter should set this value
+ appropriately.
+
"""
raise NotImplementedError()
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 29aa57faa..710c62c59 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -227,6 +227,7 @@ FUNCTIONS = {
functions.grouping_sets: "GROUPING SETS",
}
+
EXTRACT_MAP = {
"month": "month",
"day": "day",
@@ -1036,57 +1037,28 @@ class SQLCompiler(Compiled):
return pd
@util.memoized_instancemethod
- def _get_set_input_sizes_lookup(
- self, include_types=None, exclude_types=None
- ):
- if not hasattr(self, "bind_names"):
- return None
-
+ def _get_set_input_sizes_lookup(self):
dialect = self.dialect
- dbapi = self.dialect.dbapi
- # _unwrapped_dialect_impl() is necessary so that we get the
- # correct dialect type for a custom TypeDecorator, or a Variant,
- # which is also a TypeDecorator. Special types like Interval,
- # that use TypeDecorator but also might be mapped directly
- # for a dialect impl, also subclass Emulated first which overrides
- # this behavior in those cases to behave like the default.
+ include_types = dialect.include_set_input_sizes
+ exclude_types = dialect.exclude_set_input_sizes
- if include_types is None and exclude_types is None:
+ dbapi = dialect.dbapi
- def _lookup_type(typ):
- dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
- return dbtype
+ def lookup_type(typ):
+ dbtype = typ._unwrapped_dialect_impl(dialect).get_dbapi_type(dbapi)
- else:
-
- def _lookup_type(typ):
- # note we get dbtype from the possibly TypeDecorator-wrapped
- # dialect_impl, but the dialect_impl itself that we use for
- # include/exclude is the unwrapped version.
-
- dialect_impl = typ._unwrapped_dialect_impl(dialect)
-
- dbtype = typ.dialect_impl(dialect).get_dbapi_type(dbapi)
-
- if (
- dbtype is not None
- and (
- exclude_types is None
- or dbtype not in exclude_types
- and type(dialect_impl) not in exclude_types
- )
- and (
- include_types is None
- or dbtype in include_types
- or type(dialect_impl) in include_types
- )
- ):
- return dbtype
- else:
- return None
+ if (
+ dbtype is not None
+ and (exclude_types is None or dbtype not in exclude_types)
+ and (include_types is None or dbtype in include_types)
+ ):
+ return dbtype
+ else:
+ return None
inputsizes = {}
+
literal_execute_params = self.literal_execute_params
for bindparam in self.bind_names:
@@ -1095,10 +1067,10 @@ class SQLCompiler(Compiled):
if bindparam.type._is_tuple_type:
inputsizes[bindparam] = [
- _lookup_type(typ) for typ in bindparam.type.types
+ lookup_type(typ) for typ in bindparam.type.types
]
else:
- inputsizes[bindparam] = _lookup_type(bindparam.type)
+ inputsizes[bindparam] = lookup_type(bindparam.type)
return inputsizes
@@ -2061,7 +2033,25 @@ class SQLCompiler(Compiled):
parameter, values
)
- typ_dialect_impl = parameter.type._unwrapped_dialect_impl(self.dialect)
+ dialect = self.dialect
+ typ_dialect_impl = parameter.type._unwrapped_dialect_impl(dialect)
+
+ if (
+ self.dialect._bind_typing_render_casts
+ and typ_dialect_impl.render_bind_cast
+ ):
+
+ def _render_bindtemplate(name):
+ return self.render_bind_cast(
+ parameter.type,
+ typ_dialect_impl,
+ self.bindtemplate % {"name": name},
+ )
+
+ else:
+
+ def _render_bindtemplate(name):
+ return self.bindtemplate % {"name": name}
if not values:
to_update = []
@@ -2088,14 +2078,16 @@ class SQLCompiler(Compiled):
for i, tuple_element in enumerate(values, 1)
for j, value in enumerate(tuple_element, 1)
]
+
replacement_expression = (
- "VALUES " if self.dialect.tuple_in_values else ""
+ "VALUES " if dialect.tuple_in_values else ""
) + ", ".join(
"(%s)"
% (
", ".join(
- self.bindtemplate
- % {"name": to_update[i * len(tuple_element) + j][0]}
+ _render_bindtemplate(
+ to_update[i * len(tuple_element) + j][0]
+ )
for j, value in enumerate(tuple_element)
)
)
@@ -2107,7 +2099,7 @@ class SQLCompiler(Compiled):
for i, value in enumerate(values, 1)
]
replacement_expression = ", ".join(
- self.bindtemplate % {"name": key} for key, value in to_update
+ _render_bindtemplate(key) for key, value in to_update
)
return to_update, replacement_expression
@@ -2376,6 +2368,7 @@ class SQLCompiler(Compiled):
m = re.match(
r"^(.*)\(__\[POSTCOMPILE_(\S+?)\]\)(.*)$", wrapped
)
+ assert m, "unexpected format for expanding parameter"
wrapped = "(__[POSTCOMPILE_%s~~%s~~REPL~~%s~~])" % (
m.group(2),
m.group(1),
@@ -2463,13 +2456,18 @@ class SQLCompiler(Compiled):
name,
post_compile=post_compile,
expanding=bindparam.expanding,
+ bindparam_type=bindparam.type,
**kwargs
)
if bindparam.expanding:
ret = "(%s)" % ret
+
return ret
+ def render_bind_cast(self, type_, dbapi_type, sqltext):
+ raise NotImplementedError()
+
def render_literal_bindparam(
self, bindparam, render_literal_value=NO_ARG, **kw
):
@@ -2556,6 +2554,7 @@ class SQLCompiler(Compiled):
post_compile=False,
expanding=False,
escaped_from=None,
+ bindparam_type=None,
**kw
):
@@ -2583,8 +2582,18 @@ class SQLCompiler(Compiled):
self.escaped_bind_names[escaped_from] = name
if post_compile:
return "__[POSTCOMPILE_%s]" % name
- else:
- return self.bindtemplate % {"name": name}
+
+ ret = self.bindtemplate % {"name": name}
+
+ if (
+ bindparam_type is not None
+ and self.dialect._bind_typing_render_casts
+ ):
+ type_impl = bindparam_type._unwrapped_dialect_impl(self.dialect)
+ if type_impl.render_bind_cast:
+ ret = self.render_bind_cast(bindparam_type, type_impl, ret)
+
+ return ret
def visit_cte(
self,
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index 01763f266..69c4a4a76 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -51,6 +51,21 @@ class TypeEngine(Traversible):
_is_array = False
_is_type_decorator = False
+ render_bind_cast = False
+ """Render bind casts for :attr:`.BindTyping.RENDER_CASTS` mode.
+
+ If True, this type (usually a dialect level impl type) signals
+ to the compiler that a cast should be rendered around a bound parameter
+ for this type.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :class:`.BindTyping`
+
+ """
+
class Comparator(operators.ColumnOperators):
"""Base class for custom comparison operations defined at the
type level. See :attr:`.TypeEngine.comparator_factory`.
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index 485a13f82..6d1dac96f 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -267,6 +267,10 @@ class DialectSQL(CompiledSQL):
def _dialect_adjusted_statement(self, paramstyle):
stmt = re.sub(r"[\n\t]", "", self.statement)
+
+ # temporarily escape out PG double colons
+ stmt = stmt.replace("::", "!!")
+
if paramstyle == "pyformat":
stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
else:
@@ -279,6 +283,10 @@ class DialectSQL(CompiledSQL):
elif paramstyle == "numeric":
repl = None
stmt = re.sub(r":([\w_]+)", repl, stmt)
+
+ # put them back
+ stmt = stmt.replace("!!", "::")
+
return stmt
def _compare_sql(self, execute_observed, received_statement):
diff --git a/lib/sqlalchemy/testing/suite/test_types.py b/lib/sqlalchemy/testing/suite/test_types.py
index 4a5396ed8..c1cbf1ec6 100644
--- a/lib/sqlalchemy/testing/suite/test_types.py
+++ b/lib/sqlalchemy/testing/suite/test_types.py
@@ -491,19 +491,15 @@ class CastTypeDecoratorTest(_LiteralRoundTripFixture, fixtures.TestBase):
impl = String(50)
cache_ok = True
- def get_dbapi_type(self, dbapi):
- return dbapi.NUMBER
-
def column_expression(self, col):
return cast(col, Integer)
def bind_expression(self, col):
- return cast(col, String(50))
+ return cast(type_coerce(col, Integer), String(50))
return StringAsInt()
def test_special_type(self, metadata, connection, string_as_int):
-
type_ = string_as_int
t = Table("t", metadata, Column("x", type_))