summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/connectors/pyodbc.py17
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py38
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py18
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pg8000.py18
-rw-r--r--lib/sqlalchemy/engine/base.py3
-rw-r--r--lib/sqlalchemy/engine/default.py132
-rw-r--r--lib/sqlalchemy/engine/events.py7
-rw-r--r--lib/sqlalchemy/engine/interfaces.py15
-rw-r--r--lib/sqlalchemy/event/registry.py1
-rw-r--r--lib/sqlalchemy/sql/compiler.py8
10 files changed, 161 insertions, 96 deletions
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py
index e1a7c99f4..780161304 100644
--- a/lib/sqlalchemy/connectors/pyodbc.py
+++ b/lib/sqlalchemy/connectors/pyodbc.py
@@ -24,6 +24,8 @@ class PyODBCConnector(Connector):
supports_native_decimal = True
default_paramstyle = "named"
+ use_setinputsizes = True
+
# for non-DSN connections, this *may* be used to
# hold the desired driver name
pyodbc_driver_name = None
@@ -155,6 +157,21 @@ class PyODBCConnector(Connector):
version.append(n)
return tuple(version)
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ # the rules for these types seems a little strange, as you can pass
+ # non-tuples as well as tuples, however it seems to assume "0"
+ # for the subsequent values if you don't pass a tuple which fails
+ # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype
+ # that ticket #5649 is targeting.
+ cursor.setinputsizes(
+ [
+ (dbtype, None, None)
+ if not isinstance(dbtype, tuple)
+ else dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ ]
+ )
+
def set_isolation_level(self, connection, level):
# adjust for ConnectionFairy being present
# allows attribute set e.g. "connection.autocommit = True"
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index d1b69100f..7bde19090 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -687,15 +687,12 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
if self.compiled._quoted_bind_names:
self._setup_quoted_bind_names()
- self.set_input_sizes(
- self.compiled._quoted_bind_names,
- include_types=self.dialect._include_setinputsizes,
- )
-
self._generate_out_parameter_vars()
self._generate_cursor_outputtype_handler()
+ self.include_set_input_sizes = self.dialect._include_setinputsizes
+
def post_exec(self):
if self.compiled and self.out_parameters and self.compiled.returning:
# create a fake cursor result from the out parameters. unlike
@@ -746,6 +743,8 @@ class OracleDialect_cx_oracle(OracleDialect):
supports_unicode_statements = True
supports_unicode_binds = True
+ use_setinputsizes = True
+
driver = "cx_oracle"
colspecs = {
@@ -1172,6 +1171,35 @@ class OracleDialect_cx_oracle(OracleDialect):
if oci_prepared:
self.do_commit(connection.connection)
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ # not usually used, here to support if someone is modifying
+ # the dialect to use positional style
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ collection = (
+ (key, dbtype)
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ )
+ if context and context.compiled:
+ quoted_bind_names = context.compiled._quoted_bind_names
+ collection = (
+ (quoted_bind_names.get(key, key), dbtype)
+ for key, dbtype in collection
+ )
+
+ if not self.supports_unicode_binds:
+ # oracle 8 only
+ collection = (
+ (self.dialect._encoder(key)[0], dbtype)
+ for key, dbtype in collection
+ )
+
+ cursor.setinputsizes(**{key: dbtype for key, dbtype in collection})
+
def do_recover_twophase(self, connection):
connection.info.pop("cx_oracle_prepared", None)
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index a4937d0d2..7d679731b 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -245,7 +245,7 @@ class PGExecutionContext_asyncpg(PGExecutionContext):
# we have to exclude ENUM because "enum" not really a "type"
# we can cast to, it has to be the name of the type itself.
# for now we just omit it from casting
- self.set_input_sizes(exclude_types={AsyncAdapt_asyncpg_dbapi.ENUM})
+ self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM}
def create_server_side_cursor(self):
return self._dbapi_connection.cursor(server_side=True)
@@ -687,6 +687,8 @@ class PGDialect_asyncpg(PGDialect):
statement_compiler = PGCompiler_asyncpg
preparer = PGIdentifierPreparer_asyncpg
+ use_setinputsizes = True
+
use_native_uuid = True
colspecs = util.update_copy(
@@ -787,6 +789,20 @@ class PGDialect_asyncpg(PGDialect):
e, self.dbapi.InterfaceError
) and "connection is closed" in str(e)
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ cursor.setinputsizes(
+ **{
+ key: dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ }
+ )
+
def on_connect(self):
super_connect = super(PGDialect_asyncpg, self).on_connect()
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index b2faa4243..439249157 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -234,8 +234,6 @@ class PGExecutionContext_pg8000(PGExecutionContext):
if not self.compiled:
return
- self.set_input_sizes()
-
class PGCompiler_pg8000(PGCompiler):
def visit_mod_binary(self, binary, operator, **kw):
@@ -265,6 +263,8 @@ class PGDialect_pg8000(PGDialect):
statement_compiler = PGCompiler_pg8000
preparer = PGIdentifierPreparer_pg8000
+ use_setinputsizes = True
+
# reversed as of pg8000 1.16.6. 1.16.5 and lower
# are no longer compatible
description_encoding = None
@@ -407,6 +407,20 @@ class PGDialect_pg8000(PGDialect):
cursor.execute("COMMIT")
cursor.close()
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ if self.positional:
+ cursor.setinputsizes(
+ *[dbtype for key, dbtype, sqltype in list_of_tuples]
+ )
+ else:
+ cursor.setinputsizes(
+ **{
+ key: dbtype
+ for key, dbtype, sqltype in list_of_tuples
+ if dbtype
+ }
+ )
+
def do_begin_twophase(self, connection, xid):
connection.connection.tpc_begin((0, xid, ""))
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index 4fbdec145..9a5518a96 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -1553,6 +1553,9 @@ class Connection(Connectable):
context.pre_exec()
+ if dialect.use_setinputsizes:
+ context._set_input_sizes()
+
cursor, statement, parameters = (
context.cursor,
context.statement,
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index ff29c3b9d..d63cb4add 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -56,6 +56,7 @@ class DefaultDialect(interfaces.Dialect):
supports_alter = True
supports_comments = False
inline_comments = False
+ use_setinputsizes = False
# the first value we'd get for an autoincrement
# column.
@@ -782,6 +783,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
returned_default_rows = None
execution_options = util.immutabledict()
+ include_set_input_sizes = None
+ exclude_set_input_sizes = None
+
cursor_fetch_strategy = _cursor._DEFAULT_FETCH
cache_stats = None
@@ -1477,9 +1481,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.compiled.postfetch
)
- def set_input_sizes(
- self, translate=None, include_types=None, exclude_types=None
- ):
+ def _set_input_sizes(self):
"""Given a cursor and ClauseParameters, call the appropriate
style of ``setinputsizes()`` on the cursor, using DB-API types
from the bind parameter's ``TypeEngine`` objects.
@@ -1488,14 +1490,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
currently cx_oracle.
"""
- if self.isddl:
- return None
+ if self.isddl or self.is_text:
+ return
inputsizes = self.compiled._get_set_input_sizes_lookup(
- translate=translate,
- include_types=include_types,
- exclude_types=exclude_types,
+ include_types=self.include_set_input_sizes,
+ exclude_types=self.exclude_set_input_sizes,
)
+
if inputsizes is None:
return
@@ -1506,82 +1508,52 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
)
if self.dialect.positional:
- positional_inputsizes = []
- for key in self.compiled.positiontup:
- bindparam = self.compiled.binds[key]
- if bindparam in self.compiled.literal_execute_params:
- continue
-
- if key in self._expanded_parameters:
- if bindparam.type._is_tuple_type:
- num = len(bindparam.type.types)
- dbtypes = inputsizes[bindparam]
- positional_inputsizes.extend(
- [
- dbtypes[idx % num]
- for idx, key in enumerate(
- self._expanded_parameters[key]
- )
- ]
- )
- else:
- dbtype = inputsizes.get(bindparam, None)
- positional_inputsizes.extend(
- dbtype for dbtype in self._expanded_parameters[key]
- )
- else:
- dbtype = inputsizes[bindparam]
- positional_inputsizes.append(dbtype)
- try:
- self.cursor.setinputsizes(*positional_inputsizes)
- except BaseException as e:
- self.root_connection._handle_dbapi_exception(
- e, None, None, None, self
- )
+ items = [
+ (key, self.compiled.binds[key])
+ for key in self.compiled.positiontup
+ ]
else:
- keyword_inputsizes = {}
- for bindparam, key in self.compiled.bind_names.items():
- if bindparam in self.compiled.literal_execute_params:
- continue
-
- if key in self._expanded_parameters:
- if bindparam.type._is_tuple_type:
- num = len(bindparam.type.types)
- dbtypes = inputsizes[bindparam]
- keyword_inputsizes.update(
- [
- (key, dbtypes[idx % num])
- for idx, key in enumerate(
- self._expanded_parameters[key]
- )
- ]
+ items = [
+ (key, bindparam)
+ for bindparam, key in self.compiled.bind_names.items()
+ ]
+
+ generic_inputsizes = []
+ for key, bindparam in items:
+ if bindparam in self.compiled.literal_execute_params:
+ continue
+
+ if key in self._expanded_parameters:
+ if bindparam.type._is_tuple_type:
+ num = len(bindparam.type.types)
+ dbtypes = inputsizes[bindparam]
+ generic_inputsizes.extend(
+ (
+ paramname,
+ dbtypes[idx % num],
+ bindparam.type.types[idx % num],
)
- else:
- dbtype = inputsizes.get(bindparam, None)
- if dbtype is not None:
- keyword_inputsizes.update(
- (expand_key, dbtype)
- for expand_key in self._expanded_parameters[
- key
- ]
- )
+ for idx, paramname in enumerate(
+ self._expanded_parameters[key]
+ )
+ )
else:
dbtype = inputsizes.get(bindparam, None)
- if dbtype is not None:
- if translate:
- # TODO: this part won't work w/ the
- # expanded_parameters feature, e.g. for cx_oracle
- # quoted bound names
- key = translate.get(key, key)
- if not self.dialect.supports_unicode_binds:
- key = self.dialect._encoder(key)[0]
- keyword_inputsizes[key] = dbtype
- try:
- self.cursor.setinputsizes(**keyword_inputsizes)
- except BaseException as e:
- self.root_connection._handle_dbapi_exception(
- e, None, None, None, self
- )
+ generic_inputsizes.extend(
+ (paramname, dbtype, bindparam.type)
+ for paramname in self._expanded_parameters[key]
+ )
+ else:
+ dbtype = inputsizes.get(bindparam, None)
+ generic_inputsizes.append((key, dbtype, bindparam.type))
+ try:
+ self.dialect.do_set_input_sizes(
+ self.cursor, generic_inputsizes, self
+ )
+ except BaseException as e:
+ self.root_connection._handle_dbapi_exception(
+ e, None, None, None, self
+ )
def _exec_default(self, column, default, type_):
if default.is_sequence:
diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py
index 9f30a83ce..ccc6c5968 100644
--- a/lib/sqlalchemy/engine/events.py
+++ b/lib/sqlalchemy/engine/events.py
@@ -792,10 +792,9 @@ class DialectEvents(event.Events):
or a dictionary of string parameter keys to DBAPI type objects for
a named bound parameter execution style.
- Most dialects **do not use** this method at all; the only built-in
- dialect which uses this hook is the cx_Oracle dialect. The hook here
- is made available so as to allow customization of how datatypes are set
- up with the cx_Oracle DBAPI.
+ The setinputsizes hook overall is only used for dialects which include
+ the flag ``use_setinputsizes=True``. Dialects which use this
+ include cx_Oracle, pg8000, asyncpg, and pyodbc dialects.
.. versionadded:: 1.2.9
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index b7bd3627b..a7f71f5e5 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -569,6 +569,21 @@ class Dialect(object):
raise NotImplementedError()
+ def do_set_input_sizes(self, cursor, list_of_tuples, context):
+ """invoke the cursor.setinputsizes() method with appropriate arguments
+
+ This hook is called if the dialect.use_inputsizes flag is set to True.
+ Parameter data is passed in a list of tuples (paramname, dbtype,
+ sqltype), where ``paramname`` is the key of the parameter in the
+ statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the
+ SQLAlchemy type. The order of tuples is in the correct parameter order.
+
+ .. versionadded:: 1.4
+
+
+ """
+ raise NotImplementedError()
+
def create_xid(self):
"""Create a two-phase transaction ID.
diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py
index 58680f356..d1009eca9 100644
--- a/lib/sqlalchemy/event/registry.py
+++ b/lib/sqlalchemy/event/registry.py
@@ -229,6 +229,7 @@ class _EventKey(object):
"No listeners found for event %s / %r / %s "
% (self.target, self.identifier, self.fn)
)
+
dispatch_reg = _key_to_collection.pop(key)
for collection_ref, listener_ref in dispatch_reg.items():
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 2fa9961eb..23cd778d0 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -971,7 +971,7 @@ class SQLCompiler(Compiled):
@util.memoized_instancemethod
def _get_set_input_sizes_lookup(
- self, translate=None, include_types=None, exclude_types=None
+ self, include_types=None, exclude_types=None
):
if not hasattr(self, "bind_names"):
return None
@@ -986,7 +986,7 @@ class SQLCompiler(Compiled):
# for a dialect impl, also subclass Emulated first which overrides
# this behavior in those cases to behave like the default.
- if not include_types and not exclude_types:
+ if include_types is None and exclude_types is None:
def _lookup_type(typ):
dialect_impl = typ._unwrapped_dialect_impl(dialect)
@@ -1001,12 +1001,12 @@ class SQLCompiler(Compiled):
if (
dbtype is not None
and (
- not exclude_types
+ exclude_types is None
or dbtype not in exclude_types
and type(dialect_impl) not in exclude_types
)
and (
- not include_types
+ include_types is None
or dbtype in include_types
or type(dialect_impl) in include_types
)