diff options
Diffstat (limited to 'lib/sqlalchemy')
| -rw-r--r-- | lib/sqlalchemy/connectors/pyodbc.py | 17 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/oracle/cx_oracle.py | 38 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/asyncpg.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/dialects/postgresql/pg8000.py | 18 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/base.py | 3 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/default.py | 132 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/events.py | 7 | ||||
| -rw-r--r-- | lib/sqlalchemy/engine/interfaces.py | 15 | ||||
| -rw-r--r-- | lib/sqlalchemy/event/registry.py | 1 | ||||
| -rw-r--r-- | lib/sqlalchemy/sql/compiler.py | 8 |
10 files changed, 161 insertions, 96 deletions
diff --git a/lib/sqlalchemy/connectors/pyodbc.py b/lib/sqlalchemy/connectors/pyodbc.py index e1a7c99f4..780161304 100644 --- a/lib/sqlalchemy/connectors/pyodbc.py +++ b/lib/sqlalchemy/connectors/pyodbc.py @@ -24,6 +24,8 @@ class PyODBCConnector(Connector): supports_native_decimal = True default_paramstyle = "named" + use_setinputsizes = True + # for non-DSN connections, this *may* be used to # hold the desired driver name pyodbc_driver_name = None @@ -155,6 +157,21 @@ class PyODBCConnector(Connector): version.append(n) return tuple(version) + def do_set_input_sizes(self, cursor, list_of_tuples, context): + # the rules for these types seems a little strange, as you can pass + # non-tuples as well as tuples, however it seems to assume "0" + # for the subsequent values if you don't pass a tuple which fails + # for types such as pyodbc.SQL_WLONGVARCHAR, which is the datatype + # that ticket #5649 is targeting. + cursor.setinputsizes( + [ + (dbtype, None, None) + if not isinstance(dbtype, tuple) + else dbtype + for key, dbtype, sqltype in list_of_tuples + ] + ) + def set_isolation_level(self, connection, level): # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index d1b69100f..7bde19090 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -687,15 +687,12 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): if self.compiled._quoted_bind_names: self._setup_quoted_bind_names() - self.set_input_sizes( - self.compiled._quoted_bind_names, - include_types=self.dialect._include_setinputsizes, - ) - self._generate_out_parameter_vars() self._generate_cursor_outputtype_handler() + self.include_set_input_sizes = self.dialect._include_setinputsizes + def post_exec(self): if self.compiled and self.out_parameters and self.compiled.returning: # create a fake cursor result from the out parameters. unlike @@ -746,6 +743,8 @@ class OracleDialect_cx_oracle(OracleDialect): supports_unicode_statements = True supports_unicode_binds = True + use_setinputsizes = True + driver = "cx_oracle" colspecs = { @@ -1172,6 +1171,35 @@ class OracleDialect_cx_oracle(OracleDialect): if oci_prepared: self.do_commit(connection.connection) + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + # not usually used, here to support if someone is modifying + # the dialect to use positional style + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + collection = ( + (key, dbtype) + for key, dbtype, sqltype in list_of_tuples + if dbtype + ) + if context and context.compiled: + quoted_bind_names = context.compiled._quoted_bind_names + collection = ( + (quoted_bind_names.get(key, key), dbtype) + for key, dbtype in collection + ) + + if not self.supports_unicode_binds: + # oracle 8 only + collection = ( + (self.dialect._encoder(key)[0], dbtype) + for key, dbtype in collection + ) + + cursor.setinputsizes(**{key: dbtype for key, dbtype in collection}) + def do_recover_twophase(self, connection): connection.info.pop("cx_oracle_prepared", None) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index a4937d0d2..7d679731b 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -245,7 +245,7 @@ class PGExecutionContext_asyncpg(PGExecutionContext): # we have to exclude ENUM because "enum" not really a "type" # we can cast to, it has to be the name of the type itself. # for now we just omit it from casting - self.set_input_sizes(exclude_types={AsyncAdapt_asyncpg_dbapi.ENUM}) + self.exclude_set_input_sizes = {AsyncAdapt_asyncpg_dbapi.ENUM} def create_server_side_cursor(self): return self._dbapi_connection.cursor(server_side=True) @@ -687,6 +687,8 @@ class PGDialect_asyncpg(PGDialect): statement_compiler = PGCompiler_asyncpg preparer = PGIdentifierPreparer_asyncpg + use_setinputsizes = True + use_native_uuid = True colspecs = util.update_copy( @@ -787,6 +789,20 @@ class PGDialect_asyncpg(PGDialect): e, self.dbapi.InterfaceError ) and "connection is closed" in str(e) + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + cursor.setinputsizes( + **{ + key: dbtype + for key, dbtype, sqltype in list_of_tuples + if dbtype + } + ) + def on_connect(self): super_connect = super(PGDialect_asyncpg, self).on_connect() diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index b2faa4243..439249157 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -234,8 +234,6 @@ class PGExecutionContext_pg8000(PGExecutionContext): if not self.compiled: return - self.set_input_sizes() - class PGCompiler_pg8000(PGCompiler): def visit_mod_binary(self, binary, operator, **kw): @@ -265,6 +263,8 @@ class PGDialect_pg8000(PGDialect): statement_compiler = PGCompiler_pg8000 preparer = PGIdentifierPreparer_pg8000 + use_setinputsizes = True + # reversed as of pg8000 1.16.6. 1.16.5 and lower # are no longer compatible description_encoding = None @@ -407,6 +407,20 @@ class PGDialect_pg8000(PGDialect): cursor.execute("COMMIT") cursor.close() + def do_set_input_sizes(self, cursor, list_of_tuples, context): + if self.positional: + cursor.setinputsizes( + *[dbtype for key, dbtype, sqltype in list_of_tuples] + ) + else: + cursor.setinputsizes( + **{ + key: dbtype + for key, dbtype, sqltype in list_of_tuples + if dbtype + } + ) + def do_begin_twophase(self, connection, xid): connection.connection.tpc_begin((0, xid, "")) diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py index 4fbdec145..9a5518a96 100644 --- a/lib/sqlalchemy/engine/base.py +++ b/lib/sqlalchemy/engine/base.py @@ -1553,6 +1553,9 @@ class Connection(Connectable): context.pre_exec() + if dialect.use_setinputsizes: + context._set_input_sizes() + cursor, statement, parameters = ( context.cursor, context.statement, diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index ff29c3b9d..d63cb4add 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -56,6 +56,7 @@ class DefaultDialect(interfaces.Dialect): supports_alter = True supports_comments = False inline_comments = False + use_setinputsizes = False # the first value we'd get for an autoincrement # column. @@ -782,6 +783,9 @@ class DefaultExecutionContext(interfaces.ExecutionContext): returned_default_rows = None execution_options = util.immutabledict() + include_set_input_sizes = None + exclude_set_input_sizes = None + cursor_fetch_strategy = _cursor._DEFAULT_FETCH cache_stats = None @@ -1477,9 +1481,7 @@ class DefaultExecutionContext(interfaces.ExecutionContext): self.compiled.postfetch ) - def set_input_sizes( - self, translate=None, include_types=None, exclude_types=None - ): + def _set_input_sizes(self): """Given a cursor and ClauseParameters, call the appropriate style of ``setinputsizes()`` on the cursor, using DB-API types from the bind parameter's ``TypeEngine`` objects. @@ -1488,14 +1490,14 @@ class DefaultExecutionContext(interfaces.ExecutionContext): currently cx_oracle. """ - if self.isddl: - return None + if self.isddl or self.is_text: + return inputsizes = self.compiled._get_set_input_sizes_lookup( - translate=translate, - include_types=include_types, - exclude_types=exclude_types, + include_types=self.include_set_input_sizes, + exclude_types=self.exclude_set_input_sizes, ) + if inputsizes is None: return @@ -1506,82 +1508,52 @@ class DefaultExecutionContext(interfaces.ExecutionContext): ) if self.dialect.positional: - positional_inputsizes = [] - for key in self.compiled.positiontup: - bindparam = self.compiled.binds[key] - if bindparam in self.compiled.literal_execute_params: - continue - - if key in self._expanded_parameters: - if bindparam.type._is_tuple_type: - num = len(bindparam.type.types) - dbtypes = inputsizes[bindparam] - positional_inputsizes.extend( - [ - dbtypes[idx % num] - for idx, key in enumerate( - self._expanded_parameters[key] - ) - ] - ) - else: - dbtype = inputsizes.get(bindparam, None) - positional_inputsizes.extend( - dbtype for dbtype in self._expanded_parameters[key] - ) - else: - dbtype = inputsizes[bindparam] - positional_inputsizes.append(dbtype) - try: - self.cursor.setinputsizes(*positional_inputsizes) - except BaseException as e: - self.root_connection._handle_dbapi_exception( - e, None, None, None, self - ) + items = [ + (key, self.compiled.binds[key]) + for key in self.compiled.positiontup + ] else: - keyword_inputsizes = {} - for bindparam, key in self.compiled.bind_names.items(): - if bindparam in self.compiled.literal_execute_params: - continue - - if key in self._expanded_parameters: - if bindparam.type._is_tuple_type: - num = len(bindparam.type.types) - dbtypes = inputsizes[bindparam] - keyword_inputsizes.update( - [ - (key, dbtypes[idx % num]) - for idx, key in enumerate( - self._expanded_parameters[key] - ) - ] + items = [ + (key, bindparam) + for bindparam, key in self.compiled.bind_names.items() + ] + + generic_inputsizes = [] + for key, bindparam in items: + if bindparam in self.compiled.literal_execute_params: + continue + + if key in self._expanded_parameters: + if bindparam.type._is_tuple_type: + num = len(bindparam.type.types) + dbtypes = inputsizes[bindparam] + generic_inputsizes.extend( + ( + paramname, + dbtypes[idx % num], + bindparam.type.types[idx % num], ) - else: - dbtype = inputsizes.get(bindparam, None) - if dbtype is not None: - keyword_inputsizes.update( - (expand_key, dbtype) - for expand_key in self._expanded_parameters[ - key - ] - ) + for idx, paramname in enumerate( + self._expanded_parameters[key] + ) + ) else: dbtype = inputsizes.get(bindparam, None) - if dbtype is not None: - if translate: - # TODO: this part won't work w/ the - # expanded_parameters feature, e.g. for cx_oracle - # quoted bound names - key = translate.get(key, key) - if not self.dialect.supports_unicode_binds: - key = self.dialect._encoder(key)[0] - keyword_inputsizes[key] = dbtype - try: - self.cursor.setinputsizes(**keyword_inputsizes) - except BaseException as e: - self.root_connection._handle_dbapi_exception( - e, None, None, None, self - ) + generic_inputsizes.extend( + (paramname, dbtype, bindparam.type) + for paramname in self._expanded_parameters[key] + ) + else: + dbtype = inputsizes.get(bindparam, None) + generic_inputsizes.append((key, dbtype, bindparam.type)) + try: + self.dialect.do_set_input_sizes( + self.cursor, generic_inputsizes, self + ) + except BaseException as e: + self.root_connection._handle_dbapi_exception( + e, None, None, None, self + ) def _exec_default(self, column, default, type_): if default.is_sequence: diff --git a/lib/sqlalchemy/engine/events.py b/lib/sqlalchemy/engine/events.py index 9f30a83ce..ccc6c5968 100644 --- a/lib/sqlalchemy/engine/events.py +++ b/lib/sqlalchemy/engine/events.py @@ -792,10 +792,9 @@ class DialectEvents(event.Events): or a dictionary of string parameter keys to DBAPI type objects for a named bound parameter execution style. - Most dialects **do not use** this method at all; the only built-in - dialect which uses this hook is the cx_Oracle dialect. The hook here - is made available so as to allow customization of how datatypes are set - up with the cx_Oracle DBAPI. + The setinputsizes hook overall is only used for dialects which include + the flag ``use_setinputsizes=True``. Dialects which use this + include cx_Oracle, pg8000, asyncpg, and pyodbc dialects. .. versionadded:: 1.2.9 diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index b7bd3627b..a7f71f5e5 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -569,6 +569,21 @@ class Dialect(object): raise NotImplementedError() + def do_set_input_sizes(self, cursor, list_of_tuples, context): + """invoke the cursor.setinputsizes() method with appropriate arguments + + This hook is called if the dialect.use_inputsizes flag is set to True. + Parameter data is passed in a list of tuples (paramname, dbtype, + sqltype), where ``paramname`` is the key of the parameter in the + statement, ``dbtype`` is the DBAPI datatype and ``sqltype`` is the + SQLAlchemy type. The order of tuples is in the correct parameter order. + + .. versionadded:: 1.4 + + + """ + raise NotImplementedError() + def create_xid(self): """Create a two-phase transaction ID. diff --git a/lib/sqlalchemy/event/registry.py b/lib/sqlalchemy/event/registry.py index 58680f356..d1009eca9 100644 --- a/lib/sqlalchemy/event/registry.py +++ b/lib/sqlalchemy/event/registry.py @@ -229,6 +229,7 @@ class _EventKey(object): "No listeners found for event %s / %r / %s " % (self.target, self.identifier, self.fn) ) + dispatch_reg = _key_to_collection.pop(key) for collection_ref, listener_ref in dispatch_reg.items(): diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py index 2fa9961eb..23cd778d0 100644 --- a/lib/sqlalchemy/sql/compiler.py +++ b/lib/sqlalchemy/sql/compiler.py @@ -971,7 +971,7 @@ class SQLCompiler(Compiled): @util.memoized_instancemethod def _get_set_input_sizes_lookup( - self, translate=None, include_types=None, exclude_types=None + self, include_types=None, exclude_types=None ): if not hasattr(self, "bind_names"): return None @@ -986,7 +986,7 @@ class SQLCompiler(Compiled): # for a dialect impl, also subclass Emulated first which overrides # this behavior in those cases to behave like the default. - if not include_types and not exclude_types: + if include_types is None and exclude_types is None: def _lookup_type(typ): dialect_impl = typ._unwrapped_dialect_impl(dialect) @@ -1001,12 +1001,12 @@ class SQLCompiler(Compiled): if ( dbtype is not None and ( - not exclude_types + exclude_types is None or dbtype not in exclude_types and type(dialect_impl) not in exclude_types ) and ( - not include_types + include_types is None or dbtype in include_types or type(dialect_impl) in include_types ) |
