summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-10-15 18:18:03 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-10-16 14:28:04 -0400
commit87c24c498cb660e7a8d7d4dd5f630b967f79d3c8 (patch)
tree06f1113c0db30fb1471ac74e69af5a67976b1246 /lib/sqlalchemy
parent41d3e16773e84692b6625ccb67da204b5362d9c3 (diff)
downloadsqlalchemy-87c24c498cb660e7a8d7d4dd5f630b967f79d3c8.tar.gz
Genericize setinputsizes and support pyodbc
Reworked the "setinputsizes()" set of dialect hooks to be correctly extensible for any arbirary DBAPI, by allowing dialects individual hooks that may invoke cursor.setinputsizes() in the appropriate style for that DBAPI. In particular this is intended to support pyodbc's style of usage which is fundamentally different from that of cx_Oracle. Added support for pyodbc. Fixes: #5649 Change-Id: I9f1794f8368bf3663a286932cfe3992dae244a10
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/connectors/pyodbc.py17
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py38
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py18
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pg8000.py18
-rw-r--r--lib/sqlalchemy/engine/base.py3
-rw-r--r--lib/sqlalchemy/engine/default.py132
-rw-r--r--lib/sqlalchemy/engine/events.py7
-rw-r--r--lib/sqlalchemy/engine/interfaces.py15
-rw-r--r--lib/sqlalchemy/event/registry.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py8
10 files changed, 161 insertions, 96 deletions
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index e1a7c99f4..780161304 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -24,6 +24,8 @@ class PyODBCConnector(Connector):
supports_native_decimal = True
default_paramstyle = "named"
+ use_setinputsizes = True
+
# for non-DSN connections, this *may* be used to
# hold the desired driver name
pyodbc_driver_name = None
@@ -155,6 +157,21 @@ class PyODBCConnector(Connector):
version.append(n)
return tuple(version)
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ # the rules for these types seems a little strange, as you can pass
+ # non-tuples as well as tuples, however it seems to assume "0"
+ # for the subsequent values if you don't pass a tuple which fails
+ # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype
+ # that ticket #5649 is targeting.
+ cursor.setinputsizes(
+ [
+ (dbtype, None, None)
+ if not isinstance(dbtype, tuple)
+ else dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ ]
+ )
+
def set_isolation_level(self, connection, level):
# adjust for ConnectionFairy being present
# allows attribute set e.g. "connection.autocommit = True"
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index d1b69100f..7bde19090 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -687,15 +687,12 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
if self.compiled._quoted_bind_names:
self._setup_quoted_bind_names()
- self.set_input_sizes(
- self.compiled._quoted_bind_names,
- include_types=self.dialect._include_setinputsizes,
- )
-
self._generate_out_parameter_vars()
self._generate_cursor_outputtype_handler()
+ self.include_set_input_sizes = self.dialect._include_setinputsizes
+
def post_exec(self):
if self.compiled and self.out_parameters and self.compiled.returning:
# create a fake cursor result from the out parameters. unlike
@@ -746,6 +743,8 @@ class OracleDialect_cx_oracle(OracleDialect):
supports_unicode_statements = True
supports_unicode_binds = True
+ use_setinputsizes = True
+
driver = "cx_oracle"
colspecs = {
@@ -1172,6 +1171,35 @@ class OracleDialect_cx_oracle(OracleDialect):
if oci_prepared:
self.do_commit(connection.connection)
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ # not usually used, here to support if someone is modifying
+ # the dialect to use positional style
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ collection = (
+ (key, dbtype)
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ )
+ if context and context.compiled:
+ quoted_bind_names = context.compiled._quoted_bind_names
+ collection = (
+ (quoted_bind_names.get(key, key), dbtype)
+ for key, dbtype in collection
+ )
+
+ if not self.supports_unicode_binds:
+ # oracle 8 only
+ collection = (
+ (self.dialect._encoder(key)[0], dbtype)
+ for key, dbtype in collection
+ )
+
+ cursor.setinputsizes(**{key: dbtype for key, dbtype in collection})
+
def do_recover_twophase(self, connection):
connection.info.pop("cx_oracle_prepared", None)
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index a4937d0d2..7d679731b 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -245,7 +245,7 @@ class PGExecutionContext_asyncpg(PGExecutionContext):
# we have to exclude ENUM because "enum" not really a "type"
# we can cast to, it has to be the name of the type itself.
# for now we just omit it from casting
- self.set_input_sizes(exclude_types={AsyncAdapt_asyncpg_dbapi.ENUM})
+ self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(server_side=True)
@@ -687,6 +687,8 @@ class PGDialect_asyncpg(PGDialect):
statement_compiler = PGCompiler_asyncpg
preparer = PGIdentifierPreparer_asyncpg
+ use_setinputsizes = True
+
use_native_uuid = True
colspecs = util.update_copy(
@@ -787,6 +789,20 @@ class PGDialect_asyncpg(PGDialect):
e, self.dbapi.InterfaceError
) and "connection is closed" in str(e)
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ cursor.setinputsizes(
+ **{
+ key: dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ }
+ )
+
def on_connect(self):
super_connect = super(PGDialect_asyncpg, self).on_connect()
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index b2faa4243..439249157 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -234,8 +234,6 @@ class PGExecutionContext_pg8000(PGExecutionContext):
if not self.compiled:
return
- self.set_input_sizes()
-
class PGCompiler_pg8000(PGCompiler):
def visit_mod_binary(self, binary, operator, **kw):
@@ -265,6 +263,8 @@ class PGDialect_pg8000(PGDialect):
statement_compiler = PGCompiler_pg8000
preparer = PGIdentifierPreparer_pg8000
+ use_setinputsizes = True
+
# reversed as of pg8000 1.16.6. 1.16.5 and lower
# are no longer compatible
description_encoding = None
@@ -407,6 +407,20 @@ class PGDialect_pg8000(PGDialect):
cursor.execute("COMMIT")
cursor.close()
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ cursor.setinputsizes(
+ **{
+ key: dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ }
+ )
+
def do_begin_twophase(self, connection, xid):
connection.connection.tpc_begin((0, xid, ""))
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 4fbdec145..9a5518a96 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1553,6 +1553,9 @@ class Connection(Connectable):
context.pre_exec()
+ if dialect.use_setinputsizes:
+ context._set_input_sizes()
+
cursor, statement, parameters = (
context.cursor,
context.statement,
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index ff29c3b9d..d63cb4add 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -56,6 +56,7 @@ class DefaultDialect(interfaces.Dialect):
supports_alter = True
supports_comments = False
inline_comments = False
+ use_setinputsizes = False
# the first value we'd get for an autoincrement
# column.
@@ -782,6 +783,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
returned_default_rows = None
execution_options = util.immutabledict()
+ include_set_input_sizes = None
+ exclude_set_input_sizes = None
+
cursor_fetch_strategy = _cursor._DEFAULT_FETCH
cache_stats = None
@@ -1477,9 +1481,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.compiled.postfetch
)
- def set_input_sizes(
- self, translate=None, include_types=None, exclude_types=None
- ):
+ def _set_input_sizes(self):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
@@ -1488,14 +1490,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
currently cx_oracle.
"""
- if self.isddl:
- return None
+ if self.isddl or self.is_text:
+ return
inputsizes = self.compiled._get_set_input_sizes_lookup(
- translate=translate,
- include_types=include_types,
- exclude_types=exclude_types,
+ include_types=self.include_set_input_sizes,
+ exclude_types=self.exclude_set_input_sizes,
)
+
if inputsizes is None:
return
@@ -1506,82 +1508,52 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
)
if self.dialect.positional:
- positional_inputsizes = []
- for key in self.compiled.positiontup:
- bindparam = self.compiled.binds[key]
- if bindparam in self.compiled.literal_execute_params:
- continue
-
- if key in self._expanded_parameters:
- if bindparam.type._is_tuple_type:
- num = len(bindparam.type.types)
- dbtypes = inputsizes[bindparam]
- positional_inputsizes.extend(
- [
- dbtypes[idx % num]
- for idx, key in enumerate(
- self._expanded_parameters[key]
- )
- ]
- )
- else:
- dbtype = inputsizes.get(bindparam, None)
- positional_inputsizes.extend(
- dbtype for dbtype in self._expanded_parameters[key]
- )
- else:
- dbtype = inputsizes[bindparam]
- positional_inputsizes.append(dbtype)
- try:
- self.cursor.setinputsizes(*positional_inputsizes)
- except BaseException as e:
- self.root_connection._handle_dbapi_exception(
- e, None, None, None, self
- )
+ items = [
+ (key, self.compiled.binds[key])
+ for key in self.compiled.positiontup
+ ]
else:
- keyword_inputsizes = {}
- for bindparam, key in self.compiled.bind_names.items():
- if bindparam in self.compiled.literal_execute_params:
- continue
-
- if key in self._expanded_parameters:
- if bindparam.type._is_tuple_type:
- num = len(bindparam.type.types)
- dbtypes = inputsizes[bindparam]
- keyword_inputsizes.update(
- [
- (key, dbtypes[idx % num])
- for idx, key in enumerate(
- self._expanded_parameters[key]
- )
- ]
+ items = [
+ (key, bindparam)
+ for bindparam, key in self.compiled.bind_names.items()
+ ]
+
+ generic_inputsizes = []
+ for key, bindparam in items:
+ if bindparam in self.compiled.literal_execute_params:
+ continue
+
+ if key in self._expanded_parameters:
+ if bindparam.type._is_tuple_type:
+ num = len(bindparam.type.types)
+ dbtypes = inputsizes[bindparam]
+ generic_inputsizes.extend(
+ (
+ paramname,
+ dbtypes[idx % num],
+ bindparam.type.types[idx % num],
)
- else:
- dbtype = inputsizes.get(bindparam, None)
- if dbtype is not None:
- keyword_inputsizes.update(
- (expand_key, dbtype)
- for expand_key in self._expanded_parameters[
- key
- ]
- )
+ for idx, paramname in enumerate(
+ self._expanded_parameters[key]
+ )
+ )
else:
dbtype = inputsizes.get(bindparam, None)
- if dbtype is not None:
- if translate:
- # TODO: this part won't work w/ the
- # expanded_parameters feature, e.g. for cx_oracle
- # quoted bound names
- key = translate.get(key, key)
- if not self.dialect.supports_unicode_binds:
- key = self.dialect._encoder(key)[0]
- keyword_inputsizes[key] = dbtype
- try:
- self.cursor.setinputsizes(**keyword_inputsizes)
- except BaseException as e:
- self.root_connection._handle_dbapi_exception(
- e, None, None, None, self
- )
+ generic_inputsizes.extend(
+ (paramname, dbtype, bindparam.type)
+ for paramname in self._expanded_parameters[key]
+ )
+ else:
+ dbtype = inputsizes.get(bindparam, None)
+ generic_inputsizes.append((key, dbtype, bindparam.type))
+ try:
+ self.dialect.do_set_input_sizes(
+ self.cursor, generic_inputsizes, self
+ )
+ except BaseException as e:
+ self.root_connection._handle_dbapi_exception(
+ e, None, None, None, self
+ )
def _exec_default(self, column, default, type_):
if default.is_sequence:
diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py
index 9f30a83ce..ccc6c5968 100644
--- a/lib/sqlalchemy/engine/events.py
+++ b/lib/sqlalchemy/engine/events.py
@@ -792,10 +792,9 @@ class DialectEvents(event.Events):
or a dictionary of string parameter keys to DBAPI type objects for
a named bound parameter execution style.
- Most dialects **do not use** this method at all; the only built-in
- dialect which uses this hook is the cx_Oracle dialect. The hook here
- is made available so as to allow customization of how datatypes are set
- up with the cx_Oracle DBAPI.
+ The setinputsizes hook overall is only used for dialects which include
+ the flag ``use_setinputsizes=True``. Dialects which use this
+ include cx_Oracle, pg8000, asyncpg, and pyodbc dialects.
.. versionadded:: 1.2.9
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index b7bd3627b..a7f71f5e5 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -569,6 +569,21 @@ class Dialect(object):
raise NotImplementedError()
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ """invoke the cursor.setinputsizes() method with appropriate arguments
+
+ This hook is called if the dialect.use_inputsizes flag is set to True.
+ Parameter data is passed in a list of tuples (paramname, dbtype,
+ sqltype), where ``paramname`` is the key of the parameter in the
+ statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the
+ SQLAlchemy type. The order of tuples is in the correct parameter order.
+
+ .. versionadded:: 1.4
+
+
+ """
+ raise NotImplementedError()
+
def create_xid(self):
"""Create a two-phase transaction ID.
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py
index 58680f356..d1009eca9 100644
--- a/lib/sqlalchemy/event/registry.py
+++ b/lib/sqlalchemy/event/registry.py
@@ -229,6 +229,7 @@ class _EventKey(object):
"No listeners found for event %s / %r / %s "
% (self.target, self.identifier, self.fn)
)
+
dispatch_reg = _key_to_collection.pop(key)
for collection_ref, listener_ref in dispatch_reg.items():
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 2fa9961eb..23cd778d0 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -971,7 +971,7 @@ class SQLCompiler(Compiled):
@util.memoized_instancemethod
def _get_set_input_sizes_lookup(
- self, translate=None, include_types=None, exclude_types=None
+ self, include_types=None, exclude_types=None
):
if not hasattr(self, "bind_names"):
return None
@@ -986,7 +986,7 @@ class SQLCompiler(Compiled):
# for a dialect impl, also subclass Emulated first which overrides
# this behavior in those cases to behave like the default.
- if not include_types and not exclude_types:
+ if include_types is None and exclude_types is None:
def _lookup_type(typ):
dialect_impl = typ._unwrapped_dialect_impl(dialect)
@@ -1001,12 +1001,12 @@ class SQLCompiler(Compiled):
if (
dbtype is not None
and (
- not exclude_types
+ exclude_types is None
or dbtype not in exclude_types
and type(dialect_impl) not in exclude_types
)
and (
- not include_types
+ include_types is None
or dbtype in include_types
or type(dialect_impl) in include_types
)