diff options
| -rw-r--r-- | doc/build/changelog/unreleased_14/5649.rst | 11 | ||||
| -rw-r--r-- | lib/sqlalchemy/connectors/pyodbc.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/cx_oracle.py | 38 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/pg8000.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 132 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/events.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/event/registry.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 8 | ||||
| -rw-r--r-- | test/engine/test_execute.py | 221 |
12 files changed, 393 insertions, 96 deletions
diff --git a/doc/build/changelog/unreleased_14/5649.rst b/doc/build/changelog/unreleased_14/5649.rst new file mode 100644 index 000000000..20e69c4c3 --- /dev/null +++ b/doc/build/changelog/unreleased_14/5649.rst @@ -0,0 +1,11 @@ +.. change:: + :tags: bug, engine, pyodbc + :tickets: 5649 + + Reworked the "setinputsizes()" set of dialect hooks to be correctly + extensible for any arbirary DBAPI, by allowing dialects individual hooks + that may invoke cursor.setinputsizes() in the appropriate style for that + DBAPI. In particular this is intended to support pyodbc's style of usage + which is fundamentally different from that of cx_Oracle. Added support + for pyodbc. + diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index e1a7c99f4..780161304 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -24,6 +24,8 @@ class PyODBCConnector(Connector): supports_native_decimal = True default_paramstyle = "named" + use_setinputsizes = True + # for non-DSN connections, this *may* be used to # hold the desired driver name pyodbc_driver_name = None @@ -155,6 +157,21 @@ class PyODBCConnector(Connector): version.append(n) return tuple(version) + def do_set_input_sizes(self, cursor, list_of_tuples, context): + # the rules for these types seems a little strange, as you can pass + # non-tuples as well as tuples, however it seems to assume "0" + # for the subsequent values if you don't pass a tuple which fails + # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype + # that ticket #5649 is targeting. + cursor.setinputsizes( + [ + (dbtype, None, None) + if not isinstance(dbtype, tuple) + else dbtype + for key, dbtype, sqltype in list_of_tuples + ] + ) + def set_isolation_level(self, connection, level): # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index d1b69100f..7bde19090 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -687,15 +687,12 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): if self.compiled._quoted_bind_names: self._setup_quoted_bind_names() - self.set_input_sizes( - self.compiled._quoted_bind_names, - include_types=self.dialect._include_setinputsizes, - ) - self._generate_out_parameter_vars() self._generate_cursor_outputtype_handler() + self.include_set_input_sizes = self.dialect._include_setinputsizes + def post_exec(self): if self.compiled and self.out_parameters and self.compiled.returning: # create a fake cursor result from the out parameters. unlike @@ -746,6 +743,8 @@ class OracleDialect_cx_oracle(OracleDialect): supports_unicode_statements = True supports_unicode_binds = True + use_setinputsizes = True + driver = "cx_oracle" colspecs = { @@ -1172,6 +1171,35 @@ class OracleDialect_cx_oracle(OracleDialect): if oci_prepared: self.do_commit(connection.connection) + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + # not usually used, here to support if someone is modifying + # the dialect to use positional style + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + collection = ( + (key, dbtype) + for key, dbtype, sqltype in list_of_tuples + if dbtype + ) + if context and context.compiled: + quoted_bind_names = context.compiled._quoted_bind_names + collection = ( + (quoted_bind_names.get(key, key), dbtype) + for key, dbtype in collection + ) + + if not self.supports_unicode_binds: + # oracle 8 only + collection = ( + (self.dialect._encoder(key)[0], dbtype) + for key, dbtype in collection + ) + + cursor.setinputsizes(**{key: dbtype for key, dbtype in collection}) + def do_recover_twophase(self, connection): connection.info.pop("cx_oracle_prepared", None) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index a4937d0d2..7d679731b 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -245,7 +245,7 @@ class PGExecutionContext_asyncpg(PGExecutionContext): # we have to exclude ENUM because "enum" not really a "type" # we can cast to, it has to be the name of the type itself. # for now we just omit it from casting - self.set_input_sizes(exclude_types={AsyncAdapt_asyncpg_dbapi.ENUM}) + self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM} def create_server_side_cursor(self): return self._dbapi_connection.cursor(server_side=True) @@ -687,6 +687,8 @@ class PGDialect_asyncpg(PGDialect): statement_compiler = PGCompiler_asyncpg preparer = PGIdentifierPreparer_asyncpg + use_setinputsizes = True + use_native_uuid = True colspecs = util.update_copy( @@ -787,6 +789,20 @@ class PGDialect_asyncpg(PGDialect): e, self.dbapi.InterfaceError ) and "connection is closed" in str(e) + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + cursor.setinputsizes( + **{ + key: dbtype + for key, dbtype, sqltype in list_of_tuples + if dbtype + } + ) + def on_connect(self): super_connect = super(PGDialect_asyncpg, self).on_connect() diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index b2faa4243..439249157 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -234,8 +234,6 @@ class PGExecutionContext_pg8000(PGExecutionContext): if not self.compiled: return - self.set_input_sizes() - class PGCompiler_pg8000(PGCompiler): def visit_mod_binary(self, binary, operator, **kw): @@ -265,6 +263,8 @@ class PGDialect_pg8000(PGDialect): statement_compiler = PGCompiler_pg8000 preparer = PGIdentifierPreparer_pg8000 + use_setinputsizes = True + # reversed as of pg8000 1.16.6. 1.16.5 and lower # are no longer compatible description_encoding = None @@ -407,6 +407,20 @@ class PGDialect_pg8000(PGDialect): cursor.execute("COMMIT") cursor.close() + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + cursor.setinputsizes( + **{ + key: dbtype + for key, dbtype, sqltype in list_of_tuples + if dbtype + } + ) + def do_begin_twophase(self, connection, xid): connection.connection.tpc_begin((0, xid, "")) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4fbdec145..9a5518a96 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1553,6 +1553,9 @@ class Connection(Connectable): context.pre_exec() + if dialect.use_setinputsizes: + context._set_input_sizes() + cursor, statement, parameters = ( context.cursor, context.statement, diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ff29c3b9d..d63cb4add 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -56,6 +56,7 @@ class DefaultDialect(interfaces.Dialect): supports_alter = True supports_comments = False inline_comments = False + use_setinputsizes = False # the first value we'd get for an autoincrement # column. @@ -782,6 +783,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext): returned_default_rows = None execution_options = util.immutabledict() + include_set_input_sizes = None + exclude_set_input_sizes = None + cursor_fetch_strategy = _cursor._DEFAULT_FETCH cache_stats = None @@ -1477,9 +1481,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.compiled.postfetch ) - def set_input_sizes( - self, translate=None, include_types=None, exclude_types=None - ): + def _set_input_sizes(self): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. @@ -1488,14 +1490,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext): currently cx_oracle. """ - if self.isddl: - return None + if self.isddl or self.is_text: + return inputsizes = self.compiled._get_set_input_sizes_lookup( - translate=translate, - include_types=include_types, - exclude_types=exclude_types, + include_types=self.include_set_input_sizes, + exclude_types=self.exclude_set_input_sizes, ) + if inputsizes is None: return @@ -1506,82 +1508,52 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ) if self.dialect.positional: - positional_inputsizes = [] - for key in self.compiled.positiontup: - bindparam = self.compiled.binds[key] - if bindparam in self.compiled.literal_execute_params: - continue - - if key in self._expanded_parameters: - if bindparam.type._is_tuple_type: - num = len(bindparam.type.types) - dbtypes = inputsizes[bindparam] - positional_inputsizes.extend( - [ - dbtypes[idx % num] - for idx, key in enumerate( - self._expanded_parameters[key] - ) - ] - ) - else: - dbtype = inputsizes.get(bindparam, None) - positional_inputsizes.extend( - dbtype for dbtype in self._expanded_parameters[key] - ) - else: - dbtype = inputsizes[bindparam] - positional_inputsizes.append(dbtype) - try: - self.cursor.setinputsizes(*positional_inputsizes) - except BaseException as e: - self.root_connection._handle_dbapi_exception( - e, None, None, None, self - ) + items = [ + (key, self.compiled.binds[key]) + for key in self.compiled.positiontup + ] else: - keyword_inputsizes = {} - for bindparam, key in self.compiled.bind_names.items(): - if bindparam in self.compiled.literal_execute_params: - continue - - if key in self._expanded_parameters: - if bindparam.type._is_tuple_type: - num = len(bindparam.type.types) - dbtypes = inputsizes[bindparam] - keyword_inputsizes.update( - [ - (key, dbtypes[idx % num]) - for idx, key in enumerate( - self._expanded_parameters[key] - ) - ] + items = [ + (key, bindparam) + for bindparam, key in self.compiled.bind_names.items() + ] + + generic_inputsizes = [] + for key, bindparam in items: + if bindparam in self.compiled.literal_execute_params: + continue + + if key in self._expanded_parameters: + if bindparam.type._is_tuple_type: + num = len(bindparam.type.types) + dbtypes = inputsizes[bindparam] + generic_inputsizes.extend( + ( + paramname, + dbtypes[idx % num], + bindparam.type.types[idx % num], ) - else: - dbtype = inputsizes.get(bindparam, None) - if dbtype is not None: - keyword_inputsizes.update( - (expand_key, dbtype) - for expand_key in self._expanded_parameters[ - key - ] - ) + for idx, paramname in enumerate( + self._expanded_parameters[key] + ) + ) else: dbtype = inputsizes.get(bindparam, None) - if dbtype is not None: - if translate: - # TODO: this part won't work w/ the - # expanded_parameters feature, e.g. for cx_oracle - # quoted bound names - key = translate.get(key, key) - if not self.dialect.supports_unicode_binds: - key = self.dialect._encoder(key)[0] - keyword_inputsizes[key] = dbtype - try: - self.cursor.setinputsizes(**keyword_inputsizes) - except BaseException as e: - self.root_connection._handle_dbapi_exception( - e, None, None, None, self - ) + generic_inputsizes.extend( + (paramname, dbtype, bindparam.type) + for paramname in self._expanded_parameters[key] + ) + else: + dbtype = inputsizes.get(bindparam, None) + generic_inputsizes.append((key, dbtype, bindparam.type)) + try: + self.dialect.do_set_input_sizes( + self.cursor, generic_inputsizes, self + ) + except BaseException as e: + self.root_connection._handle_dbapi_exception( + e, None, None, None, self + ) def _exec_default(self, column, default, type_): if default.is_sequence: diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index 9f30a83ce..ccc6c5968 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -792,10 +792,9 @@ class DialectEvents(event.Events): or a dictionary of string parameter keys to DBAPI type objects for a named bound parameter execution style. - Most dialects **do not use** this method at all; the only built-in - dialect which uses this hook is the cx_Oracle dialect. The hook here - is made available so as to allow customization of how datatypes are set - up with the cx_Oracle DBAPI. + The setinputsizes hook overall is only used for dialects which include + the flag ``use_setinputsizes=True``. Dialects which use this + include cx_Oracle, pg8000, asyncpg, and pyodbc dialects. .. versionadded:: 1.2.9 diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index b7bd3627b..a7f71f5e5 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -569,6 +569,21 @@ class Dialect(object): raise NotImplementedError() + def do_set_input_sizes(self, cursor, list_of_tuples, context): + """invoke the cursor.setinputsizes() method with appropriate arguments + + This hook is called if the dialect.use_inputsizes flag is set to True. + Parameter data is passed in a list of tuples (paramname, dbtype, + sqltype), where ``paramname`` is the key of the parameter in the + statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the + SQLAlchemy type. The order of tuples is in the correct parameter order. + + .. versionadded:: 1.4 + + + """ + raise NotImplementedError() + def create_xid(self): """Create a two-phase transaction ID. diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index 58680f356..d1009eca9 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -229,6 +229,7 @@ class _EventKey(object): "No listeners found for event %s / %r / %s " % (self.target, self.identifier, self.fn) ) + dispatch_reg = _key_to_collection.pop(key) for collection_ref, listener_ref in dispatch_reg.items(): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2fa9961eb..23cd778d0 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -971,7 +971,7 @@ class SQLCompiler(Compiled): @util.memoized_instancemethod def _get_set_input_sizes_lookup( - self, translate=None, include_types=None, exclude_types=None + self, include_types=None, exclude_types=None ): if not hasattr(self, "bind_names"): return None @@ -986,7 +986,7 @@ class SQLCompiler(Compiled): # for a dialect impl, also subclass Emulated first which overrides # this behavior in those cases to behave like the default. - if not include_types and not exclude_types: + if include_types is None and exclude_types is None: def _lookup_type(typ): dialect_impl = typ._unwrapped_dialect_impl(dialect) @@ -1001,12 +1001,12 @@ class SQLCompiler(Compiled): if ( dbtype is not None and ( - not exclude_types + exclude_types is None or dbtype not in exclude_types and type(dialect_impl) not in exclude_types ) and ( - not include_types + include_types is None or dbtype in include_types or type(dialect_impl) in include_types ) diff --git a/test/engine/test_execute.py b/test/engine/test_execute.py index a1d6d2725..2ca6bdd7c 100644 --- a/test/engine/test_execute.py +++ b/test/engine/test_execute.py @@ -3300,3 +3300,224 @@ class FutureExecuteTest(fixtures.FutureEngineMixin, fixtures.TablesTest): "'branching' of new connections.", connection.connect, ) + + +class SetInputSizesTest(fixtures.TablesTest): + __backend__ = True + + __requires__ = ("independent_connections",) + + @classmethod + def define_tables(cls, metadata): + Table( + "users", + metadata, + Column("user_id", INT, primary_key=True, autoincrement=False), + Column("user_name", VARCHAR(20)), + ) + + @testing.fixture + def input_sizes_fixture(self): + canary = mock.Mock() + + def do_set_input_sizes(cursor, list_of_tuples, context): + if not engine.dialect.positional: + # sort by "user_id", "user_name", or otherwise + # param name for a non-positional dialect, so that we can + # confirm the ordering. mostly a py2 thing probably can't + # occur on py3.6+ since we are passing dictionaries with + # "user_id", "user_name" + list_of_tuples = sorted( + list_of_tuples, key=lambda elem: elem[0] + ) + canary.do_set_input_sizes(cursor, list_of_tuples, context) + + def pre_exec(self): + self.translate_set_input_sizes = None + self.include_set_input_sizes = None + self.exclude_set_input_sizes = None + + engine = testing_engine() + engine.connect().close() + + # the idea of this test is we fully replace the dialect + # do_set_input_sizes with a mock, and we can then intercept + # the setting passed to the dialect. the test table uses very + # "safe" datatypes so that the DBAPI does not actually need + # setinputsizes() called in order to work. + + with mock.patch.object( + engine.dialect, "use_setinputsizes", True + ), mock.patch.object( + engine.dialect, "do_set_input_sizes", do_set_input_sizes + ), mock.patch.object( + engine.dialect.execution_ctx_cls, "pre_exec", pre_exec + ): + yield engine, canary + + def test_set_input_sizes_no_event(self, input_sizes_fixture): + engine, canary = input_sizes_fixture + + with engine.connect() as conn: + conn.execute( + self.tables.users.insert(), + [ + {"user_id": 1, "user_name": "n1"}, + {"user_id": 2, "user_name": "n2"}, + ], + ) + + eq_( + canary.mock_calls, + [ + call.do_set_input_sizes( + mock.ANY, + [ + ( + "user_id", + mock.ANY, + testing.eq_type_affinity(Integer), + ), + ( + "user_name", + mock.ANY, + testing.eq_type_affinity(String), + ), + ], + mock.ANY, + ) + ], + ) + + def test_set_input_sizes_expanding_param(self, input_sizes_fixture): + engine, canary = input_sizes_fixture + + with engine.connect() as conn: + conn.execute( + select(self.tables.users).where( + self.tables.users.c.user_name.in_(["x", "y", "z"]) + ) + ) + + eq_( + canary.mock_calls, + [ + call.do_set_input_sizes( + mock.ANY, + [ + ( + "user_name_1_1", + mock.ANY, + testing.eq_type_affinity(String), + ), + ( + "user_name_1_2", + mock.ANY, + testing.eq_type_affinity(String), + ), + ( + "user_name_1_3", + mock.ANY, + testing.eq_type_affinity(String), + ), + ], + mock.ANY, + ) + ], + ) + + @testing.requires.tuple_in + def test_set_input_sizes_expanding_tuple_param(self, input_sizes_fixture): + engine, canary = input_sizes_fixture + + from sqlalchemy import tuple_ + + with engine.connect() as conn: + conn.execute( + select(self.tables.users).where( + tuple_( + self.tables.users.c.user_id, + self.tables.users.c.user_name, + ).in_([(1, "x"), (2, "y")]) + ) + ) + + eq_( + canary.mock_calls, + [ + call.do_set_input_sizes( + mock.ANY, + [ + ( + "param_1_1_1", + mock.ANY, + testing.eq_type_affinity(Integer), + ), + ( + "param_1_1_2", + mock.ANY, + testing.eq_type_affinity(String), + ), + ( + "param_1_2_1", + mock.ANY, + testing.eq_type_affinity(Integer), + ), + ( + "param_1_2_2", + mock.ANY, + testing.eq_type_affinity(String), + ), + ], + mock.ANY, + ) + ], + ) + + def test_set_input_sizes_event(self, input_sizes_fixture): + engine, canary = input_sizes_fixture + + SPECIAL_STRING = mock.Mock() + + @event.listens_for(engine, "do_setinputsizes") + def do_setinputsizes( + inputsizes, cursor, statement, parameters, context + ): + for k in inputsizes: + if k.type._type_affinity is String: + inputsizes[k] = ( + SPECIAL_STRING, + None, + 0, + ) + + with engine.connect() as conn: + conn.execute( + self.tables.users.insert(), + [ + {"user_id": 1, "user_name": "n1"}, + {"user_id": 2, "user_name": "n2"}, + ], + ) + + eq_( + canary.mock_calls, + [ + call.do_set_input_sizes( + mock.ANY, + [ + ( + "user_id", + mock.ANY, + testing.eq_type_affinity(Integer), + ), + ( + "user_name", + (SPECIAL_STRING, None, 0), + testing.eq_type_affinity(String), + ), + ], + mock.ANY, + ) + ], + ) |
