diff options
| author | Mike Bayer <mike_mp@zzzcomputing.com> | 2019-01-06 01:14:26 -0500 |
|---|---|---|
| committer | mike bayer <mike_mp@zzzcomputing.com> | 2019-01-06 17:34:50 +0000 |
| commit | 1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch) | |
| tree | 28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/dialects | |
| parent | 404e69426b05a82d905cbb3ad33adafccddb00dd (diff) | |
| download | sqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz | |
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits
applied at all.
The black run will format code consistently, however in
some cases that are prevalent in SQLAlchemy code it produces
too-long lines. The too-long lines will be resolved in the
following commit that will resolve all remaining flake8 issues
including shadowed builtins, long lines, import order, unused
imports, duplicate imports, and docstring issues.
Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/dialects')
56 files changed, 6681 insertions, 4045 deletions
diff --git a/lib/sqlalchemy/dialects/__init__.py b/lib/sqlalchemy/dialects/__init__.py index 963babcb8..65f30bb76 100644 --- a/lib/sqlalchemy/dialects/__init__.py +++ b/lib/sqlalchemy/dialects/__init__.py @@ -6,18 +6,19 @@ # the MIT License: http://www.opensource.org/licenses/mit-license.php __all__ = ( - 'firebird', - 'mssql', - 'mysql', - 'oracle', - 'postgresql', - 'sqlite', - 'sybase', + "firebird", + "mssql", + "mysql", + "oracle", + "postgresql", + "sqlite", + "sybase", ) from .. import util -_translates = {'postgres': 'postgresql'} +_translates = {"postgres": "postgresql"} + def _auto_fn(name): """default dialect importer. @@ -40,7 +41,7 @@ def _auto_fn(name): ) dialect = translated try: - module = __import__('sqlalchemy.dialects.%s' % (dialect, )).dialects + module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects except ImportError: return None @@ -51,6 +52,7 @@ def _auto_fn(name): else: return None + registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn) -plugins = util.PluginLoader("sqlalchemy.plugins")
\ No newline at end of file +plugins = util.PluginLoader("sqlalchemy.plugins") diff --git a/lib/sqlalchemy/dialects/firebird/__init__.py b/lib/sqlalchemy/dialects/firebird/__init__.py index c83db453b..510d62337 100644 --- a/lib/sqlalchemy/dialects/firebird/__init__.py +++ b/lib/sqlalchemy/dialects/firebird/__init__.py @@ -7,14 +7,35 @@ from . import base, kinterbasdb, fdb # noqa -from sqlalchemy.dialects.firebird.base import \ - SMALLINT, BIGINT, FLOAT, DATE, TIME, \ - TEXT, NUMERIC, TIMESTAMP, VARCHAR, CHAR, BLOB +from sqlalchemy.dialects.firebird.base import ( + SMALLINT, + BIGINT, + FLOAT, + DATE, + TIME, + TEXT, + NUMERIC, + TIMESTAMP, + VARCHAR, + CHAR, + BLOB, +) base.dialect = dialect = fdb.dialect __all__ = ( - 'SMALLINT', 'BIGINT', 'FLOAT', 'FLOAT', 'DATE', 'TIME', - 'TEXT', 'NUMERIC', 'FLOAT', 'TIMESTAMP', 'VARCHAR', 'CHAR', 'BLOB', - 'dialect' + "SMALLINT", + "BIGINT", + "FLOAT", + "FLOAT", + "DATE", + "TIME", + "TEXT", + "NUMERIC", + "FLOAT", + "TIMESTAMP", + "VARCHAR", + "CHAR", + "BLOB", + "dialect", ) diff --git a/lib/sqlalchemy/dialects/firebird/base.py b/lib/sqlalchemy/dialects/firebird/base.py index 7b470c189..1e9c778f3 100644 --- a/lib/sqlalchemy/dialects/firebird/base.py +++ b/lib/sqlalchemy/dialects/firebird/base.py @@ -79,48 +79,254 @@ from sqlalchemy.engine import base, default, reflection from sqlalchemy.sql import compiler from sqlalchemy.sql.elements import quoted_name -from sqlalchemy.types import (BIGINT, BLOB, DATE, FLOAT, INTEGER, NUMERIC, - SMALLINT, TEXT, TIME, TIMESTAMP, Integer) - - -RESERVED_WORDS = set([ - "active", "add", "admin", "after", "all", "alter", "and", "any", "as", - "asc", "ascending", "at", "auto", "avg", "before", "begin", "between", - "bigint", "bit_length", "blob", "both", "by", "case", "cast", "char", - "character", "character_length", "char_length", "check", "close", - "collate", "column", "commit", "committed", "computed", "conditional", - "connect", "constraint", "containing", "count", "create", "cross", - "cstring", "current", "current_connection", "current_date", - "current_role", "current_time", "current_timestamp", - "current_transaction", "current_user", "cursor", "database", "date", - "day", "dec", "decimal", "declare", "default", "delete", "desc", - "descending", "disconnect", "distinct", "do", "domain", "double", - "drop", "else", "end", "entry_point", "escape", "exception", - "execute", "exists", "exit", "external", "extract", "fetch", "file", - "filter", "float", "for", "foreign", "from", "full", "function", - "gdscode", "generator", "gen_id", "global", "grant", "group", - "having", "hour", "if", "in", "inactive", "index", "inner", - "input_type", "insensitive", "insert", "int", "integer", "into", "is", - "isolation", "join", "key", "leading", "left", "length", "level", - "like", "long", "lower", "manual", "max", "maximum_segment", "merge", - "min", "minute", "module_name", "month", "names", "national", - "natural", "nchar", "no", "not", "null", "numeric", "octet_length", - "of", "on", "only", "open", "option", "or", "order", "outer", - "output_type", "overflow", "page", "pages", "page_size", "parameter", - "password", "plan", "position", "post_event", "precision", "primary", - "privileges", "procedure", "protected", "rdb$db_key", "read", "real", - "record_version", "recreate", "recursive", "references", "release", - "reserv", "reserving", "retain", "returning_values", "returns", - "revoke", "right", "rollback", "rows", "row_count", "savepoint", - "schema", "second", "segment", "select", "sensitive", "set", "shadow", - "shared", "singular", "size", "smallint", "snapshot", "some", "sort", - "sqlcode", "stability", "start", "starting", "starts", "statistics", - "sub_type", "sum", "suspend", "table", "then", "time", "timestamp", - "to", "trailing", "transaction", "trigger", "trim", "uncommitted", - "union", "unique", "update", "upper", "user", "using", "value", - "values", "varchar", "variable", "varying", "view", "wait", "when", - "where", "while", "with", "work", "write", "year", -]) +from sqlalchemy.types import ( + BIGINT, + BLOB, + DATE, + FLOAT, + INTEGER, + NUMERIC, + SMALLINT, + TEXT, + TIME, + TIMESTAMP, + Integer, +) + + +RESERVED_WORDS = set( + [ + "active", + "add", + "admin", + "after", + "all", + "alter", + "and", + "any", + "as", + "asc", + "ascending", + "at", + "auto", + "avg", + "before", + "begin", + "between", + "bigint", + "bit_length", + "blob", + "both", + "by", + "case", + "cast", + "char", + "character", + "character_length", + "char_length", + "check", + "close", + "collate", + "column", + "commit", + "committed", + "computed", + "conditional", + "connect", + "constraint", + "containing", + "count", + "create", + "cross", + "cstring", + "current", + "current_connection", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_transaction", + "current_user", + "cursor", + "database", + "date", + "day", + "dec", + "decimal", + "declare", + "default", + "delete", + "desc", + "descending", + "disconnect", + "distinct", + "do", + "domain", + "double", + "drop", + "else", + "end", + "entry_point", + "escape", + "exception", + "execute", + "exists", + "exit", + "external", + "extract", + "fetch", + "file", + "filter", + "float", + "for", + "foreign", + "from", + "full", + "function", + "gdscode", + "generator", + "gen_id", + "global", + "grant", + "group", + "having", + "hour", + "if", + "in", + "inactive", + "index", + "inner", + "input_type", + "insensitive", + "insert", + "int", + "integer", + "into", + "is", + "isolation", + "join", + "key", + "leading", + "left", + "length", + "level", + "like", + "long", + "lower", + "manual", + "max", + "maximum_segment", + "merge", + "min", + "minute", + "module_name", + "month", + "names", + "national", + "natural", + "nchar", + "no", + "not", + "null", + "numeric", + "octet_length", + "of", + "on", + "only", + "open", + "option", + "or", + "order", + "outer", + "output_type", + "overflow", + "page", + "pages", + "page_size", + "parameter", + "password", + "plan", + "position", + "post_event", + "precision", + "primary", + "privileges", + "procedure", + "protected", + "rdb$db_key", + "read", + "real", + "record_version", + "recreate", + "recursive", + "references", + "release", + "reserv", + "reserving", + "retain", + "returning_values", + "returns", + "revoke", + "right", + "rollback", + "rows", + "row_count", + "savepoint", + "schema", + "second", + "segment", + "select", + "sensitive", + "set", + "shadow", + "shared", + "singular", + "size", + "smallint", + "snapshot", + "some", + "sort", + "sqlcode", + "stability", + "start", + "starting", + "starts", + "statistics", + "sub_type", + "sum", + "suspend", + "table", + "then", + "time", + "timestamp", + "to", + "trailing", + "transaction", + "trigger", + "trim", + "uncommitted", + "union", + "unique", + "update", + "upper", + "user", + "using", + "value", + "values", + "varchar", + "variable", + "varying", + "view", + "wait", + "when", + "where", + "while", + "with", + "work", + "write", + "year", + ] +) class _StringType(sqltypes.String): @@ -133,7 +339,8 @@ class _StringType(sqltypes.String): class VARCHAR(_StringType, sqltypes.VARCHAR): """Firebird VARCHAR type""" - __visit_name__ = 'VARCHAR' + + __visit_name__ = "VARCHAR" def __init__(self, length=None, **kwargs): super(VARCHAR, self).__init__(length=length, **kwargs) @@ -141,7 +348,8 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): class CHAR(_StringType, sqltypes.CHAR): """Firebird CHAR type""" - __visit_name__ = 'CHAR' + + __visit_name__ = "CHAR" def __init__(self, length=None, **kwargs): super(CHAR, self).__init__(length=length, **kwargs) @@ -154,32 +362,33 @@ class _FBDateTime(sqltypes.DateTime): return datetime.datetime(value.year, value.month, value.day) else: return value + return process -colspecs = { - sqltypes.DateTime: _FBDateTime -} + +colspecs = {sqltypes.DateTime: _FBDateTime} ischema_names = { - 'SHORT': SMALLINT, - 'LONG': INTEGER, - 'QUAD': FLOAT, - 'FLOAT': FLOAT, - 'DATE': DATE, - 'TIME': TIME, - 'TEXT': TEXT, - 'INT64': BIGINT, - 'DOUBLE': FLOAT, - 'TIMESTAMP': TIMESTAMP, - 'VARYING': VARCHAR, - 'CSTRING': CHAR, - 'BLOB': BLOB, + "SHORT": SMALLINT, + "LONG": INTEGER, + "QUAD": FLOAT, + "FLOAT": FLOAT, + "DATE": DATE, + "TIME": TIME, + "TEXT": TEXT, + "INT64": BIGINT, + "DOUBLE": FLOAT, + "TIMESTAMP": TIMESTAMP, + "VARYING": VARCHAR, + "CSTRING": CHAR, + "BLOB": BLOB, } # TODO: date conversion types (should be implemented as _FBDateTime, # _FBDate, etc. as bind/result functionality is required) + class FBTypeCompiler(compiler.GenericTypeCompiler): def visit_boolean(self, type_, **kw): return self.visit_SMALLINT(type_, **kw) @@ -194,11 +403,11 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): return "BLOB SUB_TYPE 0" def _extend_string(self, type_, basic): - charset = getattr(type_, 'charset', None) + charset = getattr(type_, "charset", None) if charset is None: return basic else: - return '%s CHARACTER SET %s' % (basic, charset) + return "%s CHARACTER SET %s" % (basic, charset) def visit_CHAR(self, type_, **kw): basic = super(FBTypeCompiler, self).visit_CHAR(type_, **kw) @@ -207,8 +416,8 @@ class FBTypeCompiler(compiler.GenericTypeCompiler): def visit_VARCHAR(self, type_, **kw): if not type_.length: raise exc.CompileError( - "VARCHAR requires a length on dialect %s" % - self.dialect.name) + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) basic = super(FBTypeCompiler, self).visit_VARCHAR(type_, **kw) return self._extend_string(type_, basic) @@ -228,36 +437,42 @@ class FBCompiler(sql.compiler.SQLCompiler): return "CURRENT_TIMESTAMP" def visit_startswith_op_binary(self, binary, operator, **kw): - return '%s STARTING WITH %s' % ( + return "%s STARTING WITH %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) + binary.right._compiler_dispatch(self, **kw), + ) def visit_notstartswith_op_binary(self, binary, operator, **kw): - return '%s NOT STARTING WITH %s' % ( + return "%s NOT STARTING WITH %s" % ( binary.left._compiler_dispatch(self, **kw), - binary.right._compiler_dispatch(self, **kw)) + binary.right._compiler_dispatch(self, **kw), + ) def visit_mod_binary(self, binary, operator, **kw): return "mod(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_alias(self, alias, asfrom=False, **kwargs): if self.dialect._version_two: - return super(FBCompiler, self).\ - visit_alias(alias, asfrom=asfrom, **kwargs) + return super(FBCompiler, self).visit_alias( + alias, asfrom=asfrom, **kwargs + ) else: # Override to not use the AS keyword which FB 1.5 does not like if asfrom: - alias_name = isinstance(alias.name, - expression._truncated_label) and \ - self._truncated_identifier("alias", - alias.name) or alias.name - - return self.process( - alias.original, asfrom=asfrom, **kwargs) + \ - " " + \ - self.preparer.format_alias(alias, alias_name) + alias_name = ( + isinstance(alias.name, expression._truncated_label) + and self._truncated_identifier("alias", alias.name) + or alias.name + ) + + return ( + self.process(alias.original, asfrom=asfrom, **kwargs) + + " " + + self.preparer.format_alias(alias, alias_name) + ) else: return self.process(alias.original, **kwargs) @@ -320,7 +535,7 @@ class FBCompiler(sql.compiler.SQLCompiler): for c in expression._select_iterables(returning_cols) ] - return 'RETURNING ' + ', '.join(columns) + return "RETURNING " + ", ".join(columns) class FBDDLCompiler(sql.compiler.DDLCompiler): @@ -333,27 +548,33 @@ class FBDDLCompiler(sql.compiler.DDLCompiler): # http://www.firebirdsql.org/manual/generatorguide-sqlsyntax.html if create.element.start is not None: raise NotImplemented( - "Firebird SEQUENCE doesn't support START WITH") + "Firebird SEQUENCE doesn't support START WITH" + ) if create.element.increment is not None: raise NotImplemented( - "Firebird SEQUENCE doesn't support INCREMENT BY") + "Firebird SEQUENCE doesn't support INCREMENT BY" + ) if self.dialect._version_two: - return "CREATE SEQUENCE %s" % \ - self.preparer.format_sequence(create.element) + return "CREATE SEQUENCE %s" % self.preparer.format_sequence( + create.element + ) else: - return "CREATE GENERATOR %s" % \ - self.preparer.format_sequence(create.element) + return "CREATE GENERATOR %s" % self.preparer.format_sequence( + create.element + ) def visit_drop_sequence(self, drop): """Generate a ``DROP GENERATOR`` statement for the sequence.""" if self.dialect._version_two: - return "DROP SEQUENCE %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP SEQUENCE %s" % self.preparer.format_sequence( + drop.element + ) else: - return "DROP GENERATOR %s" % \ - self.preparer.format_sequence(drop.element) + return "DROP GENERATOR %s" % self.preparer.format_sequence( + drop.element + ) class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): @@ -361,7 +582,8 @@ class FBIdentifierPreparer(sql.compiler.IdentifierPreparer): reserved_words = RESERVED_WORDS illegal_initial_characters = compiler.ILLEGAL_INITIAL_CHARACTERS.union( - ['_']) + ["_"] + ) def __init__(self, dialect): super(FBIdentifierPreparer, self).__init__(dialect, omit_schema=True) @@ -372,16 +594,16 @@ class FBExecutionContext(default.DefaultExecutionContext): """Get the next value from the sequence using ``gen_id()``.""" return self._execute_scalar( - "SELECT gen_id(%s, 1) FROM rdb$database" % - self.dialect.identifier_preparer.format_sequence(seq), - type_ + "SELECT gen_id(%s, 1) FROM rdb$database" + % self.dialect.identifier_preparer.format_sequence(seq), + type_, ) class FBDialect(default.DefaultDialect): """Firebird dialect""" - name = 'firebird' + name = "firebird" max_identifier_length = 31 @@ -413,23 +635,23 @@ class FBDialect(default.DefaultDialect): def initialize(self, connection): super(FBDialect, self).initialize(connection) - self._version_two = ('firebird' in self.server_version_info and - self.server_version_info >= (2, ) - ) or \ - ('interbase' in self.server_version_info and - self.server_version_info >= (6, ) - ) + self._version_two = ( + "firebird" in self.server_version_info + and self.server_version_info >= (2,) + ) or ( + "interbase" in self.server_version_info + and self.server_version_info >= (6,) + ) if not self._version_two: # TODO: whatever other pre < 2.0 stuff goes here self.ischema_names = ischema_names.copy() - self.ischema_names['TIMESTAMP'] = sqltypes.DATE - self.colspecs = { - sqltypes.DateTime: sqltypes.DATE - } + self.ischema_names["TIMESTAMP"] = sqltypes.DATE + self.colspecs = {sqltypes.DateTime: sqltypes.DATE} - self.implicit_returning = self._version_two and \ - self.__dict__.get('implicit_returning', True) + self.implicit_returning = self._version_two and self.__dict__.get( + "implicit_returning", True + ) def normalize_name(self, name): # Remove trailing spaces: FB uses a CHAR() type, @@ -437,8 +659,9 @@ class FBDialect(default.DefaultDialect): name = name and name.rstrip() if name is None: return None - elif name.upper() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + elif name.upper() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): return name.lower() elif name.lower() == name: return quoted_name(name, quote=True) @@ -448,8 +671,9 @@ class FBDialect(default.DefaultDialect): def denormalize_name(self, name): if name is None: return None - elif name.lower() == name and \ - not self.identifier_preparer._requires_quotes(name.lower()): + elif name.lower() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): return name.upper() else: return name @@ -522,7 +746,7 @@ class FBDialect(default.DefaultDialect): rp = connection.execute(qry, [self.denormalize_name(view_name)]) row = rp.first() if row: - return row['view_source'] + return row["view_source"] else: return None @@ -538,13 +762,13 @@ class FBDialect(default.DefaultDialect): tablename = self.denormalize_name(table_name) # get primary key fields c = connection.execute(keyqry, ["PRIMARY KEY", tablename]) - pkfields = [self.normalize_name(r['fname']) for r in c.fetchall()] - return {'constrained_columns': pkfields, 'name': None} + pkfields = [self.normalize_name(r["fname"]) for r in c.fetchall()] + return {"constrained_columns": pkfields, "name": None} @reflection.cache - def get_column_sequence(self, connection, - table_name, column_name, - schema=None, **kw): + def get_column_sequence( + self, connection, table_name, column_name, schema=None, **kw + ): tablename = self.denormalize_name(table_name) colname = self.denormalize_name(column_name) # Heuristic-query to determine the generator associated to a PK field @@ -567,7 +791,7 @@ class FBDialect(default.DefaultDialect): """ genr = connection.execute(genqry, [tablename, colname]).first() if genr is not None: - return dict(name=self.normalize_name(genr['fgenerator'])) + return dict(name=self.normalize_name(genr["fgenerator"])) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -595,7 +819,7 @@ class FBDialect(default.DefaultDialect): """ # get the PK, used to determine the eventual associated sequence pk_constraint = self.get_pk_constraint(connection, table_name) - pkey_cols = pk_constraint['constrained_columns'] + pkey_cols = pk_constraint["constrained_columns"] tablename = self.denormalize_name(table_name) # get all of the fields for this table @@ -605,26 +829,28 @@ class FBDialect(default.DefaultDialect): row = c.fetchone() if row is None: break - name = self.normalize_name(row['fname']) - orig_colname = row['fname'] + name = self.normalize_name(row["fname"]) + orig_colname = row["fname"] # get the data type - colspec = row['ftype'].rstrip() + colspec = row["ftype"].rstrip() coltype = self.ischema_names.get(colspec) if coltype is None: - util.warn("Did not recognize type '%s' of column '%s'" % - (colspec, name)) + util.warn( + "Did not recognize type '%s' of column '%s'" + % (colspec, name) + ) coltype = sqltypes.NULLTYPE - elif issubclass(coltype, Integer) and row['fprec'] != 0: + elif issubclass(coltype, Integer) and row["fprec"] != 0: coltype = NUMERIC( - precision=row['fprec'], - scale=row['fscale'] * -1) - elif colspec in ('VARYING', 'CSTRING'): - coltype = coltype(row['flen']) - elif colspec == 'TEXT': - coltype = TEXT(row['flen']) - elif colspec == 'BLOB': - if row['stype'] == 1: + precision=row["fprec"], scale=row["fscale"] * -1 + ) + elif colspec in ("VARYING", "CSTRING"): + coltype = coltype(row["flen"]) + elif colspec == "TEXT": + coltype = TEXT(row["flen"]) + elif colspec == "BLOB": + if row["stype"] == 1: coltype = TEXT() else: coltype = BLOB() @@ -633,36 +859,36 @@ class FBDialect(default.DefaultDialect): # does it have a default value? defvalue = None - if row['fdefault'] is not None: + if row["fdefault"] is not None: # the value comes down as "DEFAULT 'value'": there may be # more than one whitespace around the "DEFAULT" keyword # and it may also be lower case # (see also http://tracker.firebirdsql.org/browse/CORE-356) - defexpr = row['fdefault'].lstrip() - assert defexpr[:8].rstrip().upper() == \ - 'DEFAULT', "Unrecognized default value: %s" % \ - defexpr + defexpr = row["fdefault"].lstrip() + assert defexpr[:8].rstrip().upper() == "DEFAULT", ( + "Unrecognized default value: %s" % defexpr + ) defvalue = defexpr[8:].strip() - if defvalue == 'NULL': + if defvalue == "NULL": # Redundant defvalue = None col_d = { - 'name': name, - 'type': coltype, - 'nullable': not bool(row['null_flag']), - 'default': defvalue, - 'autoincrement': 'auto', + "name": name, + "type": coltype, + "nullable": not bool(row["null_flag"]), + "default": defvalue, + "autoincrement": "auto", } if orig_colname.lower() == orig_colname: - col_d['quote'] = True + col_d["quote"] = True # if the PK is a single field, try to see if its linked to # a sequence thru a trigger if len(pkey_cols) == 1 and name == pkey_cols[0]: seq_d = self.get_column_sequence(connection, tablename, name) if seq_d is not None: - col_d['sequence'] = seq_d + col_d["sequence"] = seq_d cols.append(col_d) return cols @@ -689,24 +915,26 @@ class FBDialect(default.DefaultDialect): tablename = self.denormalize_name(table_name) c = connection.execute(fkqry, ["FOREIGN KEY", tablename]) - fks = util.defaultdict(lambda: { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [] - }) + fks = util.defaultdict( + lambda: { + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + } + ) for row in c: - cname = self.normalize_name(row['cname']) + cname = self.normalize_name(row["cname"]) fk = fks[cname] - if not fk['name']: - fk['name'] = cname - fk['referred_table'] = self.normalize_name(row['targetrname']) - fk['constrained_columns'].append( - self.normalize_name(row['fname'])) - fk['referred_columns'].append( - self.normalize_name(row['targetfname'])) + if not fk["name"]: + fk["name"] = cname + fk["referred_table"] = self.normalize_name(row["targetrname"]) + fk["constrained_columns"].append(self.normalize_name(row["fname"])) + fk["referred_columns"].append( + self.normalize_name(row["targetfname"]) + ) return list(fks.values()) @reflection.cache @@ -729,13 +957,14 @@ class FBDialect(default.DefaultDialect): indexes = util.defaultdict(dict) for row in c: - indexrec = indexes[row['index_name']] - if 'name' not in indexrec: - indexrec['name'] = self.normalize_name(row['index_name']) - indexrec['column_names'] = [] - indexrec['unique'] = bool(row['unique_flag']) - - indexrec['column_names'].append( - self.normalize_name(row['field_name'])) + indexrec = indexes[row["index_name"]] + if "name" not in indexrec: + indexrec["name"] = self.normalize_name(row["index_name"]) + indexrec["column_names"] = [] + indexrec["unique"] = bool(row["unique_flag"]) + + indexrec["column_names"].append( + self.normalize_name(row["field_name"]) + ) return list(indexes.values()) diff --git a/lib/sqlalchemy/dialects/firebird/fdb.py b/lib/sqlalchemy/dialects/firebird/fdb.py index e8da6e1b7..5bf3d2c49 100644 --- a/lib/sqlalchemy/dialects/firebird/fdb.py +++ b/lib/sqlalchemy/dialects/firebird/fdb.py @@ -73,25 +73,23 @@ from ... import util class FBDialect_fdb(FBDialect_kinterbasdb): - - def __init__(self, enable_rowcount=True, - retaining=False, **kwargs): + def __init__(self, enable_rowcount=True, retaining=False, **kwargs): super(FBDialect_fdb, self).__init__( - enable_rowcount=enable_rowcount, - retaining=retaining, **kwargs) + enable_rowcount=enable_rowcount, retaining=retaining, **kwargs + ) @classmethod def dbapi(cls): - return __import__('fdb') + return __import__("fdb") def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if opts.get('port'): - opts['host'] = "%s/%s" % (opts['host'], opts['port']) - del opts['port'] + opts = url.translate_connect_args(username="user") + if opts.get("port"): + opts["host"] = "%s/%s" % (opts["host"], opts["port"]) + del opts["port"] opts.update(url.query) - util.coerce_kw_type(opts, 'type_conv', int) + util.coerce_kw_type(opts, "type_conv", int) return ([], opts) @@ -115,4 +113,5 @@ class FBDialect_fdb(FBDialect_kinterbasdb): return self._parse_version_info(version) + dialect = FBDialect_fdb diff --git a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py index dc88fc849..6d7144096 100644 --- a/lib/sqlalchemy/dialects/firebird/kinterbasdb.py +++ b/lib/sqlalchemy/dialects/firebird/kinterbasdb.py @@ -51,6 +51,7 @@ class _kinterbasdb_numeric(object): return str(value) else: return value + return process @@ -65,15 +66,16 @@ class _FBFloat_kinterbasdb(_kinterbasdb_numeric, sqltypes.Float): class FBExecutionContext_kinterbasdb(FBExecutionContext): @property def rowcount(self): - if self.execution_options.get('enable_rowcount', - self.dialect.enable_rowcount): + if self.execution_options.get( + "enable_rowcount", self.dialect.enable_rowcount + ): return self.cursor.rowcount else: return -1 class FBDialect_kinterbasdb(FBDialect): - driver = 'kinterbasdb' + driver = "kinterbasdb" supports_sane_rowcount = False supports_sane_multi_rowcount = False execution_ctx_cls = FBExecutionContext_kinterbasdb @@ -85,13 +87,17 @@ class FBDialect_kinterbasdb(FBDialect): { sqltypes.Numeric: _FBNumeric_kinterbasdb, sqltypes.Float: _FBFloat_kinterbasdb, - } - + }, ) - def __init__(self, type_conv=200, concurrency_level=1, - enable_rowcount=True, - retaining=False, **kwargs): + def __init__( + self, + type_conv=200, + concurrency_level=1, + enable_rowcount=True, + retaining=False, + **kwargs + ): super(FBDialect_kinterbasdb, self).__init__(**kwargs) self.enable_rowcount = enable_rowcount self.type_conv = type_conv @@ -102,7 +108,7 @@ class FBDialect_kinterbasdb(FBDialect): @classmethod def dbapi(cls): - return __import__('kinterbasdb') + return __import__("kinterbasdb") def do_execute(self, cursor, statement, parameters, context=None): # kinterbase does not accept a None, but wants an empty list @@ -116,28 +122,30 @@ class FBDialect_kinterbasdb(FBDialect): dbapi_connection.commit(self.retaining) def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if opts.get('port'): - opts['host'] = "%s/%s" % (opts['host'], opts['port']) - del opts['port'] + opts = url.translate_connect_args(username="user") + if opts.get("port"): + opts["host"] = "%s/%s" % (opts["host"], opts["port"]) + del opts["port"] opts.update(url.query) - util.coerce_kw_type(opts, 'type_conv', int) + util.coerce_kw_type(opts, "type_conv", int) - type_conv = opts.pop('type_conv', self.type_conv) - concurrency_level = opts.pop('concurrency_level', - self.concurrency_level) + type_conv = opts.pop("type_conv", self.type_conv) + concurrency_level = opts.pop( + "concurrency_level", self.concurrency_level + ) if self.dbapi is not None: - initialized = getattr(self.dbapi, 'initialized', None) + initialized = getattr(self.dbapi, "initialized", None) if initialized is None: # CVS rev 1.96 changed the name of the attribute: # http://kinterbasdb.cvs.sourceforge.net/viewvc/kinterbasdb/ # Kinterbasdb-3.0/__init__.py?r1=1.95&r2=1.96 - initialized = getattr(self.dbapi, '_initialized', False) + initialized = getattr(self.dbapi, "_initialized", False) if not initialized: - self.dbapi.init(type_conv=type_conv, - concurrency_level=concurrency_level) + self.dbapi.init( + type_conv=type_conv, concurrency_level=concurrency_level + ) return ([], opts) def _get_server_version_info(self, connection): @@ -160,25 +168,31 @@ class FBDialect_kinterbasdb(FBDialect): def _parse_version_info(self, version): m = match( - r'\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?', version) + r"\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+)( \w+ (\d+)\.(\d+))?", version + ) if not m: raise AssertionError( - "Could not determine version from string '%s'" % version) + "Could not determine version from string '%s'" % version + ) if m.group(5) != None: - return tuple([int(x) for x in m.group(6, 7, 4)] + ['firebird']) + return tuple([int(x) for x in m.group(6, 7, 4)] + ["firebird"]) else: - return tuple([int(x) for x in m.group(1, 2, 3)] + ['interbase']) + return tuple([int(x) for x in m.group(1, 2, 3)] + ["interbase"]) def is_disconnect(self, e, connection, cursor): - if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): msg = str(e) - return ('Unable to complete network request to host' in msg or - 'Invalid connection state' in msg or - 'Invalid cursor state' in msg or - 'connection shutdown' in msg) + return ( + "Unable to complete network request to host" in msg + or "Invalid connection state" in msg + or "Invalid cursor state" in msg + or "connection shutdown" in msg + ) else: return False + dialect = FBDialect_kinterbasdb diff --git a/lib/sqlalchemy/dialects/mssql/__init__.py b/lib/sqlalchemy/dialects/mssql/__init__.py index 9c861e89d..88a94fcfb 100644 --- a/lib/sqlalchemy/dialects/mssql/__init__.py +++ b/lib/sqlalchemy/dialects/mssql/__init__.py @@ -7,20 +7,74 @@ from . import base, pyodbc, adodbapi, pymssql, zxjdbc, mxodbc # noqa -from .base import \ - INTEGER, BIGINT, SMALLINT, TINYINT, VARCHAR, NVARCHAR, CHAR, \ - NCHAR, TEXT, NTEXT, DECIMAL, NUMERIC, FLOAT, DATETIME,\ - DATETIME2, DATETIMEOFFSET, DATE, TIME, SMALLDATETIME, \ - BINARY, VARBINARY, BIT, REAL, IMAGE, TIMESTAMP, ROWVERSION, \ - MONEY, SMALLMONEY, UNIQUEIDENTIFIER, SQL_VARIANT, XML +from .base import ( + INTEGER, + BIGINT, + SMALLINT, + TINYINT, + VARCHAR, + NVARCHAR, + CHAR, + NCHAR, + TEXT, + NTEXT, + DECIMAL, + NUMERIC, + FLOAT, + DATETIME, + DATETIME2, + DATETIMEOFFSET, + DATE, + TIME, + SMALLDATETIME, + BINARY, + VARBINARY, + BIT, + REAL, + IMAGE, + TIMESTAMP, + ROWVERSION, + MONEY, + SMALLMONEY, + UNIQUEIDENTIFIER, + SQL_VARIANT, + XML, +) base.dialect = dialect = pyodbc.dialect __all__ = ( - 'INTEGER', 'BIGINT', 'SMALLINT', 'TINYINT', 'VARCHAR', 'NVARCHAR', 'CHAR', - 'NCHAR', 'TEXT', 'NTEXT', 'DECIMAL', 'NUMERIC', 'FLOAT', 'DATETIME', - 'DATETIME2', 'DATETIMEOFFSET', 'DATE', 'TIME', 'SMALLDATETIME', - 'BINARY', 'VARBINARY', 'BIT', 'REAL', 'IMAGE', 'TIMESTAMP', 'ROWVERSION', - 'MONEY', 'SMALLMONEY', 'UNIQUEIDENTIFIER', 'SQL_VARIANT', 'XML', 'dialect' + "INTEGER", + "BIGINT", + "SMALLINT", + "TINYINT", + "VARCHAR", + "NVARCHAR", + "CHAR", + "NCHAR", + "TEXT", + "NTEXT", + "DECIMAL", + "NUMERIC", + "FLOAT", + "DATETIME", + "DATETIME2", + "DATETIMEOFFSET", + "DATE", + "TIME", + "SMALLDATETIME", + "BINARY", + "VARBINARY", + "BIT", + "REAL", + "IMAGE", + "TIMESTAMP", + "ROWVERSION", + "MONEY", + "SMALLMONEY", + "UNIQUEIDENTIFIER", + "SQL_VARIANT", + "XML", + "dialect", ) diff --git a/lib/sqlalchemy/dialects/mssql/adodbapi.py b/lib/sqlalchemy/dialects/mssql/adodbapi.py index e5bb9ba57..d985c3bb6 100644 --- a/lib/sqlalchemy/dialects/mssql/adodbapi.py +++ b/lib/sqlalchemy/dialects/mssql/adodbapi.py @@ -33,6 +33,7 @@ class MSDateTime_adodbapi(MSDateTime): if type(value) is datetime.date: return datetime.datetime(value.year, value.month, value.day) return value + return process @@ -41,18 +42,16 @@ class MSDialect_adodbapi(MSDialect): supports_sane_multi_rowcount = True supports_unicode = sys.maxunicode == 65535 supports_unicode_statements = True - driver = 'adodbapi' + driver = "adodbapi" @classmethod def import_dbapi(cls): import adodbapi as module + return module colspecs = util.update_copy( - MSDialect.colspecs, - { - sqltypes.DateTime: MSDateTime_adodbapi - } + MSDialect.colspecs, {sqltypes.DateTime: MSDateTime_adodbapi} ) def create_connect_args(self, url): @@ -61,14 +60,13 @@ class MSDialect_adodbapi(MSDialect): token = "'%s'" % token return token - keys = dict( - (k, check_quote(v)) for k, v in url.query.items() - ) + keys = dict((k, check_quote(v)) for k, v in url.query.items()) connectors = ["Provider=SQLOLEDB"] - if 'port' in keys: - connectors.append("Data Source=%s, %s" % - (keys.get("host"), keys.get("port"))) + if "port" in keys: + connectors.append( + "Data Source=%s, %s" % (keys.get("host"), keys.get("port")) + ) else: connectors.append("Data Source=%s" % keys.get("host")) connectors.append("Initial Catalog=%s" % keys.get("database")) @@ -81,7 +79,9 @@ class MSDialect_adodbapi(MSDialect): return [[";".join(connectors)], {}] def is_disconnect(self, e, connection, cursor): - return isinstance(e, self.dbapi.adodbapi.DatabaseError) and \ - "'connection failure'" in str(e) + return isinstance( + e, self.dbapi.adodbapi.DatabaseError + ) and "'connection failure'" in str(e) + dialect = MSDialect_adodbapi diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py index 9269225d3..161297015 100644 --- a/lib/sqlalchemy/dialects/mssql/base.py +++ b/lib/sqlalchemy/dialects/mssql/base.py @@ -655,9 +655,22 @@ from ...sql import compiler, expression, util as sql_util, quoted_name from ... import engine from ...engine import reflection, default from ... import types as sqltypes -from ...types import INTEGER, BIGINT, SMALLINT, DECIMAL, NUMERIC, \ - FLOAT, DATETIME, DATE, BINARY, \ - TEXT, VARCHAR, NVARCHAR, CHAR, NCHAR +from ...types import ( + INTEGER, + BIGINT, + SMALLINT, + DECIMAL, + NUMERIC, + FLOAT, + DATETIME, + DATE, + BINARY, + TEXT, + VARCHAR, + NVARCHAR, + CHAR, + NCHAR, +) from ...util import update_wrapper @@ -672,48 +685,202 @@ MS_2005_VERSION = (9,) MS_2000_VERSION = (8,) RESERVED_WORDS = set( - ['add', 'all', 'alter', 'and', 'any', 'as', 'asc', 'authorization', - 'backup', 'begin', 'between', 'break', 'browse', 'bulk', 'by', 'cascade', - 'case', 'check', 'checkpoint', 'close', 'clustered', 'coalesce', - 'collate', 'column', 'commit', 'compute', 'constraint', 'contains', - 'containstable', 'continue', 'convert', 'create', 'cross', 'current', - 'current_date', 'current_time', 'current_timestamp', 'current_user', - 'cursor', 'database', 'dbcc', 'deallocate', 'declare', 'default', - 'delete', 'deny', 'desc', 'disk', 'distinct', 'distributed', 'double', - 'drop', 'dump', 'else', 'end', 'errlvl', 'escape', 'except', 'exec', - 'execute', 'exists', 'exit', 'external', 'fetch', 'file', 'fillfactor', - 'for', 'foreign', 'freetext', 'freetexttable', 'from', 'full', - 'function', 'goto', 'grant', 'group', 'having', 'holdlock', 'identity', - 'identity_insert', 'identitycol', 'if', 'in', 'index', 'inner', 'insert', - 'intersect', 'into', 'is', 'join', 'key', 'kill', 'left', 'like', - 'lineno', 'load', 'merge', 'national', 'nocheck', 'nonclustered', 'not', - 'null', 'nullif', 'of', 'off', 'offsets', 'on', 'open', 'opendatasource', - 'openquery', 'openrowset', 'openxml', 'option', 'or', 'order', 'outer', - 'over', 'percent', 'pivot', 'plan', 'precision', 'primary', 'print', - 'proc', 'procedure', 'public', 'raiserror', 'read', 'readtext', - 'reconfigure', 'references', 'replication', 'restore', 'restrict', - 'return', 'revert', 'revoke', 'right', 'rollback', 'rowcount', - 'rowguidcol', 'rule', 'save', 'schema', 'securityaudit', 'select', - 'session_user', 'set', 'setuser', 'shutdown', 'some', 'statistics', - 'system_user', 'table', 'tablesample', 'textsize', 'then', 'to', 'top', - 'tran', 'transaction', 'trigger', 'truncate', 'tsequal', 'union', - 'unique', 'unpivot', 'update', 'updatetext', 'use', 'user', 'values', - 'varying', 'view', 'waitfor', 'when', 'where', 'while', 'with', - 'writetext', - ]) + [ + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "authorization", + "backup", + "begin", + "between", + "break", + "browse", + "bulk", + "by", + "cascade", + "case", + "check", + "checkpoint", + "close", + "clustered", + "coalesce", + "collate", + "column", + "commit", + "compute", + "constraint", + "contains", + "containstable", + "continue", + "convert", + "create", + "cross", + "current", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "dbcc", + "deallocate", + "declare", + "default", + "delete", + "deny", + "desc", + "disk", + "distinct", + "distributed", + "double", + "drop", + "dump", + "else", + "end", + "errlvl", + "escape", + "except", + "exec", + "execute", + "exists", + "exit", + "external", + "fetch", + "file", + "fillfactor", + "for", + "foreign", + "freetext", + "freetexttable", + "from", + "full", + "function", + "goto", + "grant", + "group", + "having", + "holdlock", + "identity", + "identity_insert", + "identitycol", + "if", + "in", + "index", + "inner", + "insert", + "intersect", + "into", + "is", + "join", + "key", + "kill", + "left", + "like", + "lineno", + "load", + "merge", + "national", + "nocheck", + "nonclustered", + "not", + "null", + "nullif", + "of", + "off", + "offsets", + "on", + "open", + "opendatasource", + "openquery", + "openrowset", + "openxml", + "option", + "or", + "order", + "outer", + "over", + "percent", + "pivot", + "plan", + "precision", + "primary", + "print", + "proc", + "procedure", + "public", + "raiserror", + "read", + "readtext", + "reconfigure", + "references", + "replication", + "restore", + "restrict", + "return", + "revert", + "revoke", + "right", + "rollback", + "rowcount", + "rowguidcol", + "rule", + "save", + "schema", + "securityaudit", + "select", + "session_user", + "set", + "setuser", + "shutdown", + "some", + "statistics", + "system_user", + "table", + "tablesample", + "textsize", + "then", + "to", + "top", + "tran", + "transaction", + "trigger", + "truncate", + "tsequal", + "union", + "unique", + "unpivot", + "update", + "updatetext", + "use", + "user", + "values", + "varying", + "view", + "waitfor", + "when", + "where", + "while", + "with", + "writetext", + ] +) class REAL(sqltypes.REAL): - __visit_name__ = 'REAL' + __visit_name__ = "REAL" def __init__(self, **kw): # REAL is a synonym for FLOAT(24) on SQL server - kw['precision'] = 24 + kw["precision"] = 24 super(REAL, self).__init__(**kw) class TINYINT(sqltypes.Integer): - __visit_name__ = 'TINYINT' + __visit_name__ = "TINYINT" # MSSQL DATE/TIME types have varied behavior, sometimes returning @@ -721,14 +888,15 @@ class TINYINT(sqltypes.Integer): # filter bind parameters into datetime objects (required by pyodbc, # not sure about other dialects). -class _MSDate(sqltypes.Date): +class _MSDate(sqltypes.Date): def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: return datetime.datetime(value.year, value.month, value.day) else: return value + return process _reg = re.compile(r"(\d+)-(\d+)-(\d+)") @@ -741,18 +909,16 @@ class _MSDate(sqltypes.Date): m = self._reg.match(value) if not m: raise ValueError( - "could not parse %r as a date value" % (value, )) - return datetime.date(*[ - int(x or 0) - for x in m.groups() - ]) + "could not parse %r as a date value" % (value,) + ) + return datetime.date(*[int(x or 0) for x in m.groups()]) else: return value + return process class TIME(sqltypes.TIME): - def __init__(self, precision=None, **kwargs): self.precision = precision super(TIME, self).__init__() @@ -763,10 +929,12 @@ class TIME(sqltypes.TIME): def process(value): if isinstance(value, datetime.datetime): value = datetime.datetime.combine( - self.__zero_date, value.time()) + self.__zero_date, value.time() + ) elif isinstance(value, datetime.time): value = datetime.datetime.combine(self.__zero_date, value) return value + return process _reg = re.compile(r"(\d+):(\d+):(\d+)(?:\.(\d{0,6}))?") @@ -779,24 +947,26 @@ class TIME(sqltypes.TIME): m = self._reg.match(value) if not m: raise ValueError( - "could not parse %r as a time value" % (value, )) - return datetime.time(*[ - int(x or 0) - for x in m.groups()]) + "could not parse %r as a time value" % (value,) + ) + return datetime.time(*[int(x or 0) for x in m.groups()]) else: return value + return process + + _MSTime = TIME class _DateTimeBase(object): - def bind_processor(self, dialect): def process(value): if type(value) == datetime.date: return datetime.datetime(value.year, value.month, value.day) else: return value + return process @@ -805,11 +975,11 @@ class _MSDateTime(_DateTimeBase, sqltypes.DateTime): class SMALLDATETIME(_DateTimeBase, sqltypes.DateTime): - __visit_name__ = 'SMALLDATETIME' + __visit_name__ = "SMALLDATETIME" class DATETIME2(_DateTimeBase, sqltypes.DateTime): - __visit_name__ = 'DATETIME2' + __visit_name__ = "DATETIME2" def __init__(self, precision=None, **kw): super(DATETIME2, self).__init__(**kw) @@ -818,7 +988,7 @@ class DATETIME2(_DateTimeBase, sqltypes.DateTime): # TODO: is this not an Interval ? class DATETIMEOFFSET(sqltypes.TypeEngine): - __visit_name__ = 'DATETIMEOFFSET' + __visit_name__ = "DATETIMEOFFSET" def __init__(self, precision=None, **kwargs): self.precision = precision @@ -847,7 +1017,7 @@ class TIMESTAMP(sqltypes._Binary): """ - __visit_name__ = 'TIMESTAMP' + __visit_name__ = "TIMESTAMP" # expected by _Binary to be present length = None @@ -866,12 +1036,14 @@ class TIMESTAMP(sqltypes._Binary): def result_processor(self, dialect, coltype): super_ = super(TIMESTAMP, self).result_processor(dialect, coltype) if self.convert_int: + def process(value): value = super_(value) if value is not None: # https://stackoverflow.com/a/30403242/34549 - value = int(codecs.encode(value, 'hex'), 16) + value = int(codecs.encode(value, "hex"), 16) return value + return process else: return super_ @@ -898,7 +1070,7 @@ class ROWVERSION(TIMESTAMP): """ - __visit_name__ = 'ROWVERSION' + __visit_name__ = "ROWVERSION" class NTEXT(sqltypes.UnicodeText): @@ -906,7 +1078,7 @@ class NTEXT(sqltypes.UnicodeText): """MSSQL NTEXT type, for variable-length unicode text up to 2^30 characters.""" - __visit_name__ = 'NTEXT' + __visit_name__ = "NTEXT" class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): @@ -925,11 +1097,12 @@ class VARBINARY(sqltypes.VARBINARY, sqltypes.LargeBinary): """ - __visit_name__ = 'VARBINARY' + + __visit_name__ = "VARBINARY" class IMAGE(sqltypes.LargeBinary): - __visit_name__ = 'IMAGE' + __visit_name__ = "IMAGE" class XML(sqltypes.Text): @@ -943,19 +1116,20 @@ class XML(sqltypes.Text): .. versionadded:: 1.1.11 """ - __visit_name__ = 'XML' + + __visit_name__ = "XML" class BIT(sqltypes.TypeEngine): - __visit_name__ = 'BIT' + __visit_name__ = "BIT" class MONEY(sqltypes.TypeEngine): - __visit_name__ = 'MONEY' + __visit_name__ = "MONEY" class SMALLMONEY(sqltypes.TypeEngine): - __visit_name__ = 'SMALLMONEY' + __visit_name__ = "SMALLMONEY" class UNIQUEIDENTIFIER(sqltypes.TypeEngine): @@ -963,7 +1137,8 @@ class UNIQUEIDENTIFIER(sqltypes.TypeEngine): class SQL_VARIANT(sqltypes.TypeEngine): - __visit_name__ = 'SQL_VARIANT' + __visit_name__ = "SQL_VARIANT" + # old names. MSDateTime = _MSDateTime @@ -990,36 +1165,36 @@ MSUniqueIdentifier = UNIQUEIDENTIFIER MSVariant = SQL_VARIANT ischema_names = { - 'int': INTEGER, - 'bigint': BIGINT, - 'smallint': SMALLINT, - 'tinyint': TINYINT, - 'varchar': VARCHAR, - 'nvarchar': NVARCHAR, - 'char': CHAR, - 'nchar': NCHAR, - 'text': TEXT, - 'ntext': NTEXT, - 'decimal': DECIMAL, - 'numeric': NUMERIC, - 'float': FLOAT, - 'datetime': DATETIME, - 'datetime2': DATETIME2, - 'datetimeoffset': DATETIMEOFFSET, - 'date': DATE, - 'time': TIME, - 'smalldatetime': SMALLDATETIME, - 'binary': BINARY, - 'varbinary': VARBINARY, - 'bit': BIT, - 'real': REAL, - 'image': IMAGE, - 'xml': XML, - 'timestamp': TIMESTAMP, - 'money': MONEY, - 'smallmoney': SMALLMONEY, - 'uniqueidentifier': UNIQUEIDENTIFIER, - 'sql_variant': SQL_VARIANT, + "int": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "tinyint": TINYINT, + "varchar": VARCHAR, + "nvarchar": NVARCHAR, + "char": CHAR, + "nchar": NCHAR, + "text": TEXT, + "ntext": NTEXT, + "decimal": DECIMAL, + "numeric": NUMERIC, + "float": FLOAT, + "datetime": DATETIME, + "datetime2": DATETIME2, + "datetimeoffset": DATETIMEOFFSET, + "date": DATE, + "time": TIME, + "smalldatetime": SMALLDATETIME, + "binary": BINARY, + "varbinary": VARBINARY, + "bit": BIT, + "real": REAL, + "image": IMAGE, + "xml": XML, + "timestamp": TIMESTAMP, + "money": MONEY, + "smallmoney": SMALLMONEY, + "uniqueidentifier": UNIQUEIDENTIFIER, + "sql_variant": SQL_VARIANT, } @@ -1030,8 +1205,8 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): """ - if getattr(type_, 'collation', None): - collation = 'COLLATE %s' % type_.collation + if getattr(type_, "collation", None): + collation = "COLLATE %s" % type_.collation else: collation = None @@ -1041,15 +1216,14 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): if length: spec = spec + "(%s)" % length - return ' '.join([c for c in (spec, collation) - if c is not None]) + return " ".join([c for c in (spec, collation) if c is not None]) def visit_FLOAT(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is None: return "FLOAT" else: - return "FLOAT(%(precision)s)" % {'precision': precision} + return "FLOAT(%(precision)s)" % {"precision": precision} def visit_TINYINT(self, type_, **kw): return "TINYINT" @@ -1061,7 +1235,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "DATETIMEOFFSET" def visit_TIME(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is not None: return "TIME(%s)" % precision else: @@ -1074,7 +1248,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "ROWVERSION" def visit_DATETIME2(self, type_, **kw): - precision = getattr(type_, 'precision', None) + precision = getattr(type_, "precision", None) if precision is not None: return "DATETIME2(%s)" % precision else: @@ -1105,7 +1279,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self._extend("TEXT", type_) def visit_VARCHAR(self, type_, **kw): - return self._extend("VARCHAR", type_, length=type_.length or 'max') + return self._extend("VARCHAR", type_, length=type_.length or "max") def visit_CHAR(self, type_, **kw): return self._extend("CHAR", type_) @@ -1114,7 +1288,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return self._extend("NCHAR", type_) def visit_NVARCHAR(self, type_, **kw): - return self._extend("NVARCHAR", type_, length=type_.length or 'max') + return self._extend("NVARCHAR", type_, length=type_.length or "max") def visit_date(self, type_, **kw): if self.dialect.server_version_info < MS_2008_VERSION: @@ -1141,10 +1315,7 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "XML" def visit_VARBINARY(self, type_, **kw): - return self._extend( - "VARBINARY", - type_, - length=type_.length or 'max') + return self._extend("VARBINARY", type_, length=type_.length or "max") def visit_boolean(self, type_, **kw): return self.visit_BIT(type_) @@ -1156,13 +1327,13 @@ class MSTypeCompiler(compiler.GenericTypeCompiler): return "MONEY" def visit_SMALLMONEY(self, type_, **kw): - return 'SMALLMONEY' + return "SMALLMONEY" def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" def visit_SQL_VARIANT(self, type_, **kw): - return 'SQL_VARIANT' + return "SQL_VARIANT" class MSExecutionContext(default.DefaultExecutionContext): @@ -1186,41 +1357,44 @@ class MSExecutionContext(default.DefaultExecutionContext): insert_has_sequence = seq_column is not None if insert_has_sequence: - self._enable_identity_insert = \ - seq_column.key in self.compiled_parameters[0] or \ - ( - self.compiled.statement.parameters and ( - ( - self.compiled.statement._has_multi_parameters - and - seq_column.key in - self.compiled.statement.parameters[0] - ) or ( - not - self.compiled.statement._has_multi_parameters - and - seq_column.key in - self.compiled.statement.parameters - ) + self._enable_identity_insert = seq_column.key in self.compiled_parameters[ + 0 + ] or ( + self.compiled.statement.parameters + and ( + ( + self.compiled.statement._has_multi_parameters + and seq_column.key + in self.compiled.statement.parameters[0] + ) + or ( + not self.compiled.statement._has_multi_parameters + and seq_column.key + in self.compiled.statement.parameters ) ) + ) else: self._enable_identity_insert = False - self._select_lastrowid = not self.compiled.inline and \ - insert_has_sequence and \ - not self.compiled.returning and \ - not self._enable_identity_insert and \ - not self.executemany + self._select_lastrowid = ( + not self.compiled.inline + and insert_has_sequence + and not self.compiled.returning + and not self._enable_identity_insert + and not self.executemany + ) if self._enable_identity_insert: self.root_connection._cursor_execute( self.cursor, self._opt_encode( - "SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(tbl)), + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(tbl) + ), (), - self) + self, + ) def post_exec(self): """Disable IDENTITY_INSERT if enabled.""" @@ -1230,29 +1404,35 @@ class MSExecutionContext(default.DefaultExecutionContext): if self.dialect.use_scope_identity: conn._cursor_execute( self.cursor, - "SELECT scope_identity() AS lastrowid", (), self) + "SELECT scope_identity() AS lastrowid", + (), + self, + ) else: - conn._cursor_execute(self.cursor, - "SELECT @@identity AS lastrowid", - (), - self) + conn._cursor_execute( + self.cursor, "SELECT @@identity AS lastrowid", (), self + ) # fetchall() ensures the cursor is consumed without closing it row = self.cursor.fetchall()[0] self._lastrowid = int(row[0]) - if (self.isinsert or self.isupdate or self.isdelete) and \ - self.compiled.returning: + if ( + self.isinsert or self.isupdate or self.isdelete + ) and self.compiled.returning: self._result_proxy = engine.FullyBufferedResultProxy(self) if self._enable_identity_insert: conn._cursor_execute( self.cursor, self._opt_encode( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. format_table( - self.compiled.statement.table)), + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) + ), (), - self) + self, + ) def get_lastrowid(self): return self._lastrowid @@ -1262,9 +1442,12 @@ class MSExecutionContext(default.DefaultExecutionContext): try: self.cursor.execute( self._opt_encode( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. format_table( - self.compiled.statement.table))) + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) + ) + ) except Exception: pass @@ -1281,11 +1464,12 @@ class MSSQLCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, { - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond', - 'microseconds': 'microsecond' - }) + "doy": "dayofyear", + "dow": "weekday", + "milliseconds": "millisecond", + "microseconds": "microsecond", + }, + ) def __init__(self, *args, **kwargs): self.tablealiases = {} @@ -1298,6 +1482,7 @@ class MSSQLCompiler(compiler.SQLCompiler): else: super_ = getattr(super(MSSQLCompiler, self), fn.__name__) return super_(*arg, **kw) + return decorate def visit_now_func(self, fn, **kw): @@ -1313,20 +1498,22 @@ class MSSQLCompiler(compiler.SQLCompiler): return "LEN%s" % self.function_argspec(fn, **kw) def visit_concat_op_binary(self, binary, operator, **kw): - return "%s + %s" % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "%s + %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def visit_match_op_binary(self, binary, operator, **kw): return "CONTAINS (%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def get_select_precolumns(self, select, **kw): """ MS-SQL puts TOP, it's version of LIMIT here """ @@ -1345,7 +1532,8 @@ class MSSQLCompiler(compiler.SQLCompiler): return s else: return compiler.SQLCompiler.get_select_precolumns( - self, select, **kw) + self, select, **kw + ) def get_from_hint_text(self, table, text): return text @@ -1363,20 +1551,21 @@ class MSSQLCompiler(compiler.SQLCompiler): """ if ( - ( - not select._simple_int_limit and - select._limit_clause is not None - ) or ( - select._offset_clause is not None and - not select._simple_int_offset or select._offset + (not select._simple_int_limit and select._limit_clause is not None) + or ( + select._offset_clause is not None + and not select._simple_int_offset + or select._offset ) - ) and not getattr(select, '_mssql_visit', None): + ) and not getattr(select, "_mssql_visit", None): # to use ROW_NUMBER(), an ORDER BY is required. if not select._order_by_clause.clauses: - raise exc.CompileError('MSSQL requires an order_by when ' - 'using an OFFSET or a non-simple ' - 'LIMIT clause') + raise exc.CompileError( + "MSSQL requires an order_by when " + "using an OFFSET or a non-simple " + "LIMIT clause" + ) _order_by_clauses = [ sql_util.unwrap_label_reference(elem) @@ -1385,24 +1574,31 @@ class MSSQLCompiler(compiler.SQLCompiler): limit_clause = select._limit_clause offset_clause = select._offset_clause - kwargs['select_wraps_for'] = select + kwargs["select_wraps_for"] = select select = select._generate() select._mssql_visit = True - select = select.column( - sql.func.ROW_NUMBER().over(order_by=_order_by_clauses) - .label("mssql_rn")).order_by(None).alias() + select = ( + select.column( + sql.func.ROW_NUMBER() + .over(order_by=_order_by_clauses) + .label("mssql_rn") + ) + .order_by(None) + .alias() + ) - mssql_rn = sql.column('mssql_rn') - limitselect = sql.select([c for c in select.c if - c.key != 'mssql_rn']) + mssql_rn = sql.column("mssql_rn") + limitselect = sql.select( + [c for c in select.c if c.key != "mssql_rn"] + ) if offset_clause is not None: limitselect.append_whereclause(mssql_rn > offset_clause) if limit_clause is not None: limitselect.append_whereclause( - mssql_rn <= (limit_clause + offset_clause)) + mssql_rn <= (limit_clause + offset_clause) + ) else: - limitselect.append_whereclause( - mssql_rn <= (limit_clause)) + limitselect.append_whereclause(mssql_rn <= (limit_clause)) return self.process(limitselect, **kwargs) else: return compiler.SQLCompiler.visit_select(self, select, **kwargs) @@ -1422,35 +1618,38 @@ class MSSQLCompiler(compiler.SQLCompiler): @_with_legacy_schema_aliasing def visit_alias(self, alias, **kw): # translate for schema-qualified table aliases - kw['mssql_aliased'] = alias.original + kw["mssql_aliased"] = alias.original return super(MSSQLCompiler, self).visit_alias(alias, **kw) @_with_legacy_schema_aliasing def visit_column(self, column, add_to_result_map=None, **kw): - if column.table is not None and \ - (not self.isupdate and not self.isdelete) or \ - self.is_subquery(): + if ( + column.table is not None + and (not self.isupdate and not self.isdelete) + or self.is_subquery() + ): # translate for schema-qualified table aliases t = self._schema_aliased_table(column.table) if t is not None: converted = expression._corresponding_column_or_error( - t, column) + t, column + ) if add_to_result_map is not None: add_to_result_map( column.name, column.name, (column, column.name, column.key), - column.type + column.type, ) - return super(MSSQLCompiler, self).\ - visit_column(converted, **kw) + return super(MSSQLCompiler, self).visit_column(converted, **kw) return super(MSSQLCompiler, self).visit_column( - column, add_to_result_map=add_to_result_map, **kw) + column, add_to_result_map=add_to_result_map, **kw + ) def _schema_aliased_table(self, table): - if getattr(table, 'schema', None) is not None: + if getattr(table, "schema", None) is not None: if table not in self.tablealiases: self.tablealiases[table] = table.alias() return self.tablealiases[table] @@ -1459,16 +1658,17 @@ class MSSQLCompiler(compiler.SQLCompiler): def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART(%s, %s)' % \ - (field, self.process(extract.expr, **kw)) + return "DATEPART(%s, %s)" % (field, self.process(extract.expr, **kw)) def visit_savepoint(self, savepoint_stmt): - return "SAVE TRANSACTION %s" % \ - self.preparer.format_savepoint(savepoint_stmt) + return "SAVE TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_rollback_to_savepoint(self, savepoint_stmt): - return ("ROLLBACK TRANSACTION %s" - % self.preparer.format_savepoint(savepoint_stmt)) + return "ROLLBACK TRANSACTION %s" % self.preparer.format_savepoint( + savepoint_stmt + ) def visit_binary(self, binary, **kwargs): """Move bind parameters to the right-hand side of an operator, where @@ -1481,10 +1681,11 @@ class MSSQLCompiler(compiler.SQLCompiler): and not isinstance(binary.right, expression.BindParameter) ): return self.process( - expression.BinaryExpression(binary.right, - binary.left, - binary.operator), - **kwargs) + expression.BinaryExpression( + binary.right, binary.left, binary.operator + ), + **kwargs + ) return super(MSSQLCompiler, self).visit_binary(binary, **kwargs) def returning_clause(self, stmt, returning_cols): @@ -1497,12 +1698,13 @@ class MSSQLCompiler(compiler.SQLCompiler): adapter = sql_util.ClauseAdapter(target) columns = [ - self._label_select_column(None, adapter.traverse(c), - True, False, {}) + self._label_select_column( + None, adapter.traverse(c), True, False, {} + ) for c in expression._select_iterables(returning_cols) ] - return 'OUTPUT ' + ', '.join(columns) + return "OUTPUT " + ", ".join(columns) def get_cte_preamble(self, recursive): # SQL Server finds it too inconvenient to accept @@ -1515,13 +1717,14 @@ class MSSQLCompiler(compiler.SQLCompiler): if isinstance(column, expression.Function): return column.label(None) else: - return super(MSSQLCompiler, self).\ - label_select_column(select, column, asfrom) + return super(MSSQLCompiler, self).label_select_column( + select, column, asfrom + ) def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" which # SQLAlchemy doesn't use - return '' + return "" def order_by_clause(self, select, **kw): order_by = self.process(select._order_by_clause, **kw) @@ -1532,10 +1735,9 @@ class MSSQLCompiler(compiler.SQLCompiler): else: return "" - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the UPDATE..FROM clause specific to MSSQL. In MSSQL, if the UPDATE statement involves an alias of the table to @@ -1543,13 +1745,12 @@ class MSSQLCompiler(compiler.SQLCompiler): well. Otherwise, it is optional. Here, we add it regardless. """ - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1558,20 +1759,21 @@ class MSSQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. FROM clause specific to MSSQL. Yes, it has the FROM keyword twice. """ - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) def visit_empty_set_expr(self, type_): - return 'SELECT 1 WHERE 1!=1' + return "SELECT 1 WHERE 1!=1" class MSSQLStrictCompiler(MSSQLCompiler): @@ -1583,20 +1785,21 @@ class MSSQLStrictCompiler(MSSQLCompiler): binds are used. """ + ansi_bind_rules = True def visit_in_op_binary(self, binary, operator, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True return "%s IN %s" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_notin_op_binary(self, binary, operator, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True return "%s NOT IN %s" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def render_literal_value(self, value, type_): @@ -1615,23 +1818,28 @@ class MSSQLStrictCompiler(MSSQLCompiler): # SQL Server wants single quotes around the date string. return "'" + str(value) + "'" else: - return super(MSSQLStrictCompiler, self).\ - render_literal_value(value, type_) + return super(MSSQLStrictCompiler, self).render_literal_value( + value, type_ + ) class MSDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): colspec = ( - self.preparer.format_column(column) + " " + self.preparer.format_column(column) + + " " + self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ) ) if column.nullable is not None: - if not column.nullable or column.primary_key or \ - isinstance(column.default, sa_schema.Sequence) or \ - column.autoincrement is True: + if ( + not column.nullable + or column.primary_key + or isinstance(column.default, sa_schema.Sequence) + or column.autoincrement is True + ): colspec += " NOT NULL" else: colspec += " NULL" @@ -1639,15 +1847,18 @@ class MSDDLCompiler(compiler.DDLCompiler): if column.table is None: raise exc.CompileError( "mssql requires Table-bound columns " - "in order to generate DDL") + "in order to generate DDL" + ) # install an IDENTITY Sequence if we either a sequence or an implicit # IDENTITY column if isinstance(column.default, sa_schema.Sequence): - if (column.default.start is not None or - column.default.increment is not None or - column is not column.table._autoincrement_column): + if ( + column.default.start is not None + or column.default.increment is not None + or column is not column.table._autoincrement_column + ): util.warn_deprecated( "Use of Sequence with SQL Server in order to affect the " "parameters of the IDENTITY value is deprecated, as " @@ -1655,18 +1866,23 @@ class MSDDLCompiler(compiler.DDLCompiler): "will correspond to an actual SQL Server " "CREATE SEQUENCE in " "a future release. Please use the mssql_identity_start " - "and mssql_identity_increment parameters.") + "and mssql_identity_increment parameters." + ) if column.default.start == 0: start = 0 else: start = column.default.start or 1 - colspec += " IDENTITY(%s,%s)" % (start, - column.default.increment or 1) - elif column is column.table._autoincrement_column or \ - column.autoincrement is True: - start = column.dialect_options['mssql']['identity_start'] - increment = column.dialect_options['mssql']['identity_increment'] + colspec += " IDENTITY(%s,%s)" % ( + start, + column.default.increment or 1, + ) + elif ( + column is column.table._autoincrement_column + or column.autoincrement is True + ): + start = column.dialect_options["mssql"]["identity_start"] + increment = column.dialect_options["mssql"]["identity_increment"] colspec += " IDENTITY(%s,%s)" % (start, increment) else: default = self.get_column_default_string(column) @@ -1684,84 +1900,88 @@ class MSDDLCompiler(compiler.DDLCompiler): text += "UNIQUE " # handle clustering option - clustered = index.dialect_options['mssql']['clustered'] + clustered = index.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=include_schema), - preparer.format_table(index.table), - ', '.join( - self.sql_compiler.process(expr, - include_table=False, - literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=include_schema), + preparer.format_table(index.table), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) # handle other included columns - if index.dialect_options['mssql']['include']: - inclusions = [index.table.c[col] - if isinstance(col, util.string_types) else col - for col in - index.dialect_options['mssql']['include'] - ] + if index.dialect_options["mssql"]["include"]: + inclusions = [ + index.table.c[col] + if isinstance(col, util.string_types) + else col + for col in index.dialect_options["mssql"]["include"] + ] - text += " INCLUDE (%s)" \ - % ', '.join([preparer.quote(c.name) - for c in inclusions]) + text += " INCLUDE (%s)" % ", ".join( + [preparer.quote(c.name) for c in inclusions] + ) return text def visit_drop_index(self, drop): return "\nDROP INDEX %s ON %s" % ( self._prepared_index_name(drop.element, include_schema=False), - self.preparer.format_table(drop.element.table) + self.preparer.format_table(drop.element.table), ) def visit_primary_key_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) text += "PRIMARY KEY " - clustered = constraint.dialect_options['mssql']['clustered'] + clustered = constraint.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) text += self.define_constraint_deferrability(constraint) return text def visit_unique_constraint(self, constraint): if len(constraint) == 0: - return '' + return "" text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) text += "UNIQUE " - clustered = constraint.dialect_options['mssql']['clustered'] + clustered = constraint.dialect_options["mssql"]["clustered"] if clustered is not None: if clustered: text += "CLUSTERED " else: text += "NONCLUSTERED " - text += "(%s)" % ', '.join(self.preparer.quote(c.name) - for c in constraint) + text += "(%s)" % ", ".join( + self.preparer.quote(c.name) for c in constraint + ) text += self.define_constraint_deferrability(constraint) return text @@ -1771,8 +1991,11 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): def __init__(self, dialect): super(MSIdentifierPreparer, self).__init__( - dialect, initial_quote='[', - final_quote=']', quote_case_sensitive_collations=False) + dialect, + initial_quote="[", + final_quote="]", + quote_case_sensitive_collations=False, + ) def _escape_identifier(self, value): return value @@ -1783,7 +2006,9 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): dbname, owner = _schema_elements(schema) if dbname: result = "%s.%s" % ( - self.quote(dbname, force), self.quote(owner, force)) + self.quote(dbname, force), + self.quote(owner, force), + ) elif owner: result = self.quote(owner, force) else: @@ -1794,16 +2019,37 @@ class MSIdentifierPreparer(compiler.IdentifierPreparer): def _db_plus_owner_listing(fn): def wrap(dialect, connection, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) - return _switch_db(dbname, connection, fn, dialect, connection, - dbname, owner, schema, **kw) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + dbname, + owner, + schema, + **kw + ) + return update_wrapper(wrap, fn) def _db_plus_owner(fn): def wrap(dialect, connection, tablename, schema=None, **kw): dbname, owner = _owner_plus_db(dialect, schema) - return _switch_db(dbname, connection, fn, dialect, connection, - tablename, dbname, owner, schema, **kw) + return _switch_db( + dbname, + connection, + fn, + dialect, + connection, + tablename, + dbname, + owner, + schema, + **kw + ) + return update_wrapper(wrap, fn) @@ -1837,9 +2083,9 @@ def _schema_elements(schema): for token in re.split(r"(\[|\]|\.)", schema): if not token: continue - if token == '[': + if token == "[": bracket = True - elif token == ']': + elif token == "]": bracket = False elif not bracket and token == ".": push.append(symbol) @@ -1857,7 +2103,7 @@ def _schema_elements(schema): class MSDialect(default.DefaultDialect): - name = 'mssql' + name = "mssql" supports_default_values = True supports_empty_insert = False execution_ctx_cls = MSExecutionContext @@ -1871,9 +2117,9 @@ class MSDialect(default.DefaultDialect): sqltypes.Time: TIME, } - engine_config_types = default.DefaultDialect.engine_config_types.union([ - ('legacy_schema_aliasing', util.asbool), - ]) + engine_config_types = default.DefaultDialect.engine_config_types.union( + [("legacy_schema_aliasing", util.asbool)] + ) ischema_names = ischema_names @@ -1890,36 +2136,30 @@ class MSDialect(default.DefaultDialect): preparer = MSIdentifierPreparer construct_arguments = [ - (sa_schema.PrimaryKeyConstraint, { - "clustered": None - }), - (sa_schema.UniqueConstraint, { - "clustered": None - }), - (sa_schema.Index, { - "clustered": None, - "include": None - }), - (sa_schema.Column, { - "identity_start": 1, - "identity_increment": 1 - }) + (sa_schema.PrimaryKeyConstraint, {"clustered": None}), + (sa_schema.UniqueConstraint, {"clustered": None}), + (sa_schema.Index, {"clustered": None, "include": None}), + (sa_schema.Column, {"identity_start": 1, "identity_increment": 1}), ] - def __init__(self, - query_timeout=None, - use_scope_identity=True, - max_identifier_length=None, - schema_name="dbo", - isolation_level=None, - deprecate_large_types=None, - legacy_schema_aliasing=False, **opts): + def __init__( + self, + query_timeout=None, + use_scope_identity=True, + max_identifier_length=None, + schema_name="dbo", + isolation_level=None, + deprecate_large_types=None, + legacy_schema_aliasing=False, + **opts + ): self.query_timeout = int(query_timeout or 0) self.schema_name = schema_name self.use_scope_identity = use_scope_identity - self.max_identifier_length = int(max_identifier_length or 0) or \ - self.max_identifier_length + self.max_identifier_length = ( + int(max_identifier_length or 0) or self.max_identifier_length + ) self.deprecate_large_types = deprecate_large_types self.legacy_schema_aliasing = legacy_schema_aliasing @@ -1936,27 +2176,33 @@ class MSDialect(default.DefaultDialect): # SQL Server does not support RELEASE SAVEPOINT pass - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ', - 'SNAPSHOT']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "SNAPSHOT", + ] + ) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() - cursor.execute( - "SET TRANSACTION ISOLATION LEVEL %s" % level) + cursor.execute("SET TRANSACTION ISOLATION LEVEL %s" % level) cursor.close() def get_isolation_level(self, connection): if self.server_version_info < MS_2005_VERSION: raise NotImplementedError( - "Can't fetch isolation level prior to SQL Server 2005") + "Can't fetch isolation level prior to SQL Server 2005" + ) last_error = None @@ -1964,7 +2210,8 @@ class MSDialect(default.DefaultDialect): for view in views: cursor = connection.cursor() try: - cursor.execute(""" + cursor.execute( + """ SELECT CASE transaction_isolation_level WHEN 0 THEN NULL WHEN 1 THEN 'READ UNCOMMITTED' @@ -1974,7 +2221,9 @@ class MSDialect(default.DefaultDialect): WHEN 5 THEN 'SNAPSHOT' END AS TRANSACTION_ISOLATION_LEVEL FROM %s where session_id = @@SPID - """ % view) + """ + % view + ) val = cursor.fetchone()[0] except self.dbapi.Error as err: # Python3 scoping rules @@ -1987,7 +2236,8 @@ class MSDialect(default.DefaultDialect): else: util.warn( "Could not fetch transaction isolation level, " - "tried views: %s; final error was: %s" % (views, last_error)) + "tried views: %s; final error was: %s" % (views, last_error) + ) raise NotImplementedError( "Can't fetch isolation level on this particular " @@ -2000,8 +2250,10 @@ class MSDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None @@ -2010,16 +2262,20 @@ class MSDialect(default.DefaultDialect): if self.server_version_info[0] not in list(range(8, 17)): util.warn( "Unrecognized server version info '%s'. Some SQL Server " - "features may not function properly." % - ".".join(str(x) for x in self.server_version_info)) - if self.server_version_info >= MS_2005_VERSION and \ - 'implicit_returning' not in self.__dict__: + "features may not function properly." + % ".".join(str(x) for x in self.server_version_info) + ) + if ( + self.server_version_info >= MS_2005_VERSION + and "implicit_returning" not in self.__dict__ + ): self.implicit_returning = True if self.server_version_info >= MS_2008_VERSION: self.supports_multivalues_insert = True if self.deprecate_large_types is None: - self.deprecate_large_types = \ + self.deprecate_large_types = ( self.server_version_info >= MS_2012_VERSION + ) def _get_default_schema_name(self, connection): if self.server_version_info < MS_2005_VERSION: @@ -2039,17 +2295,19 @@ class MSDialect(default.DefaultDialect): whereclause = columns.c.table_name == tablename if owner: - whereclause = sql.and_(whereclause, - columns.c.table_schema == owner) + whereclause = sql.and_( + whereclause, columns.c.table_schema == owner + ) s = sql.select([columns], whereclause) c = connection.execute(s) return c.first() is not None @reflection.cache def get_schema_names(self, connection, **kw): - s = sql.select([ischema.schemata.c.schema_name], - order_by=[ischema.schemata.c.schema_name] - ) + s = sql.select( + [ischema.schemata.c.schema_name], + order_by=[ischema.schemata.c.schema_name], + ) schema_names = [r[0] for r in connection.execute(s)] return schema_names @@ -2057,12 +2315,13 @@ class MSDialect(default.DefaultDialect): @_db_plus_owner_listing def get_table_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables - s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'BASE TABLE' - ), - order_by=[tables.c.table_name] + s = sql.select( + [tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, + tables.c.table_type == "BASE TABLE", + ), + order_by=[tables.c.table_name], ) table_names = [r[0] for r in connection.execute(s)] return table_names @@ -2071,12 +2330,12 @@ class MSDialect(default.DefaultDialect): @_db_plus_owner_listing def get_view_names(self, connection, dbname, owner, schema, **kw): tables = ischema.tables - s = sql.select([tables.c.table_name], - sql.and_( - tables.c.table_schema == owner, - tables.c.table_type == 'VIEW' - ), - order_by=[tables.c.table_name] + s = sql.select( + [tables.c.table_name], + sql.and_( + tables.c.table_schema == owner, tables.c.table_type == "VIEW" + ), + order_by=[tables.c.table_name], ) view_names = [r[0] for r in connection.execute(s)] return view_names @@ -2090,30 +2349,33 @@ class MSDialect(default.DefaultDialect): return [] rp = connection.execute( - sql.text("select ind.index_id, ind.is_unique, ind.name " - "from sys.indexes as ind join sys.tables as tab on " - "ind.object_id=tab.object_id " - "join sys.schemas as sch on sch.schema_id=tab.schema_id " - "where tab.name = :tabname " - "and sch.name=:schname " - "and ind.is_primary_key=0 and ind.type != 0", - bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) - ], - typemap={ - 'name': sqltypes.Unicode() - } - ) + sql.text( + "select ind.index_id, ind.is_unique, ind.name " + "from sys.indexes as ind join sys.tables as tab on " + "ind.object_id=tab.object_id " + "join sys.schemas as sch on sch.schema_id=tab.schema_id " + "where tab.name = :tabname " + "and sch.name=:schname " + "and ind.is_primary_key=0 and ind.type != 0", + bindparams=[ + sql.bindparam( + "tabname", + tablename, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), + ], + typemap={"name": sqltypes.Unicode()}, + ) ) indexes = {} for row in rp: - indexes[row['index_id']] = { - 'name': row['name'], - 'unique': row['is_unique'] == 1, - 'column_names': [] + indexes[row["index_id"]] = { + "name": row["name"], + "unique": row["is_unique"] == 1, + "column_names": [], } rp = connection.execute( sql.text( @@ -2127,24 +2389,29 @@ class MSDialect(default.DefaultDialect): "where tab.name=:tabname " "and sch.name=:schname", bindparams=[ - sql.bindparam('tabname', tablename, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) + sql.bindparam( + "tabname", + tablename, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), ], - typemap={'name': sqltypes.Unicode()} - ), + typemap={"name": sqltypes.Unicode()}, + ) ) for row in rp: - if row['index_id'] in indexes: - indexes[row['index_id']]['column_names'].append(row['name']) + if row["index_id"] in indexes: + indexes[row["index_id"]]["column_names"].append(row["name"]) return list(indexes.values()) @reflection.cache @_db_plus_owner - def get_view_definition(self, connection, viewname, - dbname, owner, schema, **kw): + def get_view_definition( + self, connection, viewname, dbname, owner, schema, **kw + ): rp = connection.execute( sql.text( "select definition from sys.sql_modules as mod, " @@ -2155,11 +2422,15 @@ class MSDialect(default.DefaultDialect): "views.schema_id=sch.schema_id and " "views.name=:viewname and sch.name=:schname", bindparams=[ - sql.bindparam('viewname', viewname, - sqltypes.String(convert_unicode=True)), - sql.bindparam('schname', owner, - sqltypes.String(convert_unicode=True)) - ] + sql.bindparam( + "viewname", + viewname, + sqltypes.String(convert_unicode=True), + ), + sql.bindparam( + "schname", owner, sqltypes.String(convert_unicode=True) + ), + ], ) ) @@ -2173,12 +2444,15 @@ class MSDialect(default.DefaultDialect): # Get base columns columns = ischema.columns if owner: - whereclause = sql.and_(columns.c.table_name == tablename, - columns.c.table_schema == owner) + whereclause = sql.and_( + columns.c.table_name == tablename, + columns.c.table_schema == owner, + ) else: whereclause = columns.c.table_name == tablename - s = sql.select([columns], whereclause, - order_by=[columns.c.ordinal_position]) + s = sql.select( + [columns], whereclause, order_by=[columns.c.ordinal_position] + ) c = connection.execute(s) cols = [] @@ -2186,57 +2460,76 @@ class MSDialect(default.DefaultDialect): row = c.fetchone() if row is None: break - (name, type, nullable, charlen, - numericprec, numericscale, default, collation) = ( + ( + name, + type, + nullable, + charlen, + numericprec, + numericscale, + default, + collation, + ) = ( row[columns.c.column_name], row[columns.c.data_type], - row[columns.c.is_nullable] == 'YES', + row[columns.c.is_nullable] == "YES", row[columns.c.character_maximum_length], row[columns.c.numeric_precision], row[columns.c.numeric_scale], row[columns.c.column_default], - row[columns.c.collation_name] + row[columns.c.collation_name], ) coltype = self.ischema_names.get(type, None) kwargs = {} - if coltype in (MSString, MSChar, MSNVarchar, MSNChar, MSText, - MSNText, MSBinary, MSVarBinary, - sqltypes.LargeBinary): + if coltype in ( + MSString, + MSChar, + MSNVarchar, + MSNChar, + MSText, + MSNText, + MSBinary, + MSVarBinary, + sqltypes.LargeBinary, + ): if charlen == -1: charlen = None - kwargs['length'] = charlen + kwargs["length"] = charlen if collation: - kwargs['collation'] = collation + kwargs["collation"] = collation if coltype is None: util.warn( - "Did not recognize type '%s' of column '%s'" % - (type, name)) + "Did not recognize type '%s' of column '%s'" % (type, name) + ) coltype = sqltypes.NULLTYPE else: - if issubclass(coltype, sqltypes.Numeric) and \ - coltype is not MSReal: - kwargs['scale'] = numericscale - kwargs['precision'] = numericprec + if ( + issubclass(coltype, sqltypes.Numeric) + and coltype is not MSReal + ): + kwargs["scale"] = numericscale + kwargs["precision"] = numericprec coltype = coltype(**kwargs) cdict = { - 'name': name, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': False, + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": False, } cols.append(cdict) # autoincrement and identity colmap = {} for col in cols: - colmap[col['name']] = col + colmap[col["name"]] = col # We also run an sp_columns to check for identity columns: - cursor = connection.execute("sp_columns @table_name = '%s', " - "@table_owner = '%s'" - % (tablename, owner)) + cursor = connection.execute( + "sp_columns @table_name = '%s', " + "@table_owner = '%s'" % (tablename, owner) + ) ic = None while True: row = cursor.fetchone() @@ -2245,10 +2538,10 @@ class MSDialect(default.DefaultDialect): (col_name, type_name) = row[3], row[5] if type_name.endswith("identity") and col_name in colmap: ic = col_name - colmap[col_name]['autoincrement'] = True - colmap[col_name]['dialect_options'] = { - 'mssql_identity_start': 1, - 'mssql_identity_increment': 1 + colmap[col_name]["autoincrement"] = True + colmap[col_name]["dialect_options"] = { + "mssql_identity_start": 1, + "mssql_identity_increment": 1, } break cursor.close() @@ -2262,64 +2555,74 @@ class MSDialect(default.DefaultDialect): row = cursor.first() if row is not None and row[0] is not None: - colmap[ic]['dialect_options'].update({ - 'mssql_identity_start': int(row[0]), - 'mssql_identity_increment': int(row[1]) - }) + colmap[ic]["dialect_options"].update( + { + "mssql_identity_start": int(row[0]), + "mssql_identity_increment": int(row[1]), + } + ) return cols @reflection.cache @_db_plus_owner - def get_pk_constraint(self, connection, tablename, - dbname, owner, schema, **kw): + def get_pk_constraint( + self, connection, tablename, dbname, owner, schema, **kw + ): pkeys = [] TC = ischema.constraints - C = ischema.key_constraints.alias('C') + C = ischema.key_constraints.alias("C") # Primary key constraints - s = sql.select([C.c.column_name, - TC.c.constraint_type, - C.c.constraint_name], - sql.and_(TC.c.constraint_name == C.c.constraint_name, - TC.c.table_schema == C.c.table_schema, - C.c.table_name == tablename, - C.c.table_schema == owner) - ) + s = sql.select( + [C.c.column_name, TC.c.constraint_type, C.c.constraint_name], + sql.and_( + TC.c.constraint_name == C.c.constraint_name, + TC.c.table_schema == C.c.table_schema, + C.c.table_name == tablename, + C.c.table_schema == owner, + ), + ) c = connection.execute(s) constraint_name = None for row in c: - if 'PRIMARY' in row[TC.c.constraint_type.name]: + if "PRIMARY" in row[TC.c.constraint_type.name]: pkeys.append(row[0]) if constraint_name is None: constraint_name = row[C.c.constraint_name.name] - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, "name": constraint_name} @reflection.cache @_db_plus_owner - def get_foreign_keys(self, connection, tablename, - dbname, owner, schema, **kw): + def get_foreign_keys( + self, connection, tablename, dbname, owner, schema, **kw + ): RR = ischema.ref_constraints - C = ischema.key_constraints.alias('C') - R = ischema.key_constraints.alias('R') + C = ischema.key_constraints.alias("C") + R = ischema.key_constraints.alias("R") # Foreign key constraints - s = sql.select([C.c.column_name, - R.c.table_schema, R.c.table_name, R.c.column_name, - RR.c.constraint_name, RR.c.match_option, - RR.c.update_rule, - RR.c.delete_rule], - sql.and_(C.c.table_name == tablename, - C.c.table_schema == owner, - RR.c.constraint_schema == C.c.table_schema, - C.c.constraint_name == RR.c.constraint_name, - R.c.constraint_name == - RR.c.unique_constraint_name, - R.c.constraint_schema == - RR.c.unique_constraint_schema, - C.c.ordinal_position == R.c.ordinal_position - ), - order_by=[RR.c.constraint_name, R.c.ordinal_position] - ) + s = sql.select( + [ + C.c.column_name, + R.c.table_schema, + R.c.table_name, + R.c.column_name, + RR.c.constraint_name, + RR.c.match_option, + RR.c.update_rule, + RR.c.delete_rule, + ], + sql.and_( + C.c.table_name == tablename, + C.c.table_schema == owner, + RR.c.constraint_schema == C.c.table_schema, + C.c.constraint_name == RR.c.constraint_name, + R.c.constraint_name == RR.c.unique_constraint_name, + R.c.constraint_schema == RR.c.unique_constraint_schema, + C.c.ordinal_position == R.c.ordinal_position, + ), + order_by=[RR.c.constraint_name, R.c.ordinal_position], + ) # group rows by constraint ID, to handle multi-column FKs fkeys = [] @@ -2327,11 +2630,11 @@ class MSDialect(default.DefaultDialect): def fkey_rec(): return { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [] + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], } fkeys = util.defaultdict(fkey_rec) @@ -2340,17 +2643,18 @@ class MSDialect(default.DefaultDialect): scol, rschema, rtbl, rcol, rfknm, fkmatch, fkuprule, fkdelrule = r rec = fkeys[rfknm] - rec['name'] = rfknm - if not rec['referred_table']: - rec['referred_table'] = rtbl + rec["name"] = rfknm + if not rec["referred_table"]: + rec["referred_table"] = rtbl if schema is not None or owner != rschema: if dbname: rschema = dbname + "." + rschema - rec['referred_schema'] = rschema + rec["referred_schema"] = rschema - local_cols, remote_cols = \ - rec['constrained_columns'],\ - rec['referred_columns'] + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) local_cols.append(scol) remote_cols.append(rcol) diff --git a/lib/sqlalchemy/dialects/mssql/information_schema.py b/lib/sqlalchemy/dialects/mssql/information_schema.py index 3682fae48..c4ea8ab0c 100644 --- a/lib/sqlalchemy/dialects/mssql/information_schema.py +++ b/lib/sqlalchemy/dialects/mssql/information_schema.py @@ -38,102 +38,122 @@ class _cast_on_2005(expression.ColumnElement): @compiles(_cast_on_2005) def _compile(element, compiler, **kw): from . import base - if compiler.dialect.server_version_info is None or \ - compiler.dialect.server_version_info < base.MS_2005_VERSION: + + if ( + compiler.dialect.server_version_info is None + or compiler.dialect.server_version_info < base.MS_2005_VERSION + ): return compiler.process(element.bindvalue, **kw) else: return compiler.process(cast(element.bindvalue, Unicode), **kw) -schemata = Table("SCHEMATA", ischema, - Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), - Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), - Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), - schema="INFORMATION_SCHEMA") - -tables = Table("TABLES", ischema, - Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column( - "TABLE_TYPE", String(convert_unicode=True), - key="table_type"), - schema="INFORMATION_SCHEMA") - -columns = Table("COLUMNS", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("COLUMN_NAME", CoerceUnicode, key="column_name"), - Column("IS_NULLABLE", Integer, key="is_nullable"), - Column("DATA_TYPE", String, key="data_type"), - Column("ORDINAL_POSITION", Integer, key="ordinal_position"), - Column("CHARACTER_MAXIMUM_LENGTH", Integer, - key="character_maximum_length"), - Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), - Column("NUMERIC_SCALE", Integer, key="numeric_scale"), - Column("COLUMN_DEFAULT", Integer, key="column_default"), - Column("COLLATION_NAME", String, key="collation_name"), - schema="INFORMATION_SCHEMA") - -constraints = Table("TABLE_CONSTRAINTS", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("CONSTRAINT_NAME", CoerceUnicode, - key="constraint_name"), - Column("CONSTRAINT_TYPE", String( - convert_unicode=True), key="constraint_type"), - schema="INFORMATION_SCHEMA") - -column_constraints = Table("CONSTRAINT_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, - key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, - key="table_name"), - Column("COLUMN_NAME", CoerceUnicode, - key="column_name"), - Column("CONSTRAINT_NAME", CoerceUnicode, - key="constraint_name"), - schema="INFORMATION_SCHEMA") - -key_constraints = Table("KEY_COLUMN_USAGE", ischema, - Column("TABLE_SCHEMA", CoerceUnicode, - key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, - key="table_name"), - Column("COLUMN_NAME", CoerceUnicode, - key="column_name"), - Column("CONSTRAINT_NAME", CoerceUnicode, - key="constraint_name"), - Column("CONSTRAINT_SCHEMA", CoerceUnicode, - key="constraint_schema"), - Column("ORDINAL_POSITION", Integer, - key="ordinal_position"), - schema="INFORMATION_SCHEMA") - -ref_constraints = Table("REFERENTIAL_CONSTRAINTS", ischema, - Column("CONSTRAINT_CATALOG", CoerceUnicode, - key="constraint_catalog"), - Column("CONSTRAINT_SCHEMA", CoerceUnicode, - key="constraint_schema"), - Column("CONSTRAINT_NAME", CoerceUnicode, - key="constraint_name"), - # TODO: is CATLOG misspelled ? - Column("UNIQUE_CONSTRAINT_CATLOG", CoerceUnicode, - key="unique_constraint_catalog"), - - Column("UNIQUE_CONSTRAINT_SCHEMA", CoerceUnicode, - key="unique_constraint_schema"), - Column("UNIQUE_CONSTRAINT_NAME", CoerceUnicode, - key="unique_constraint_name"), - Column("MATCH_OPTION", String, key="match_option"), - Column("UPDATE_RULE", String, key="update_rule"), - Column("DELETE_RULE", String, key="delete_rule"), - schema="INFORMATION_SCHEMA") - -views = Table("VIEWS", ischema, - Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), - Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), - Column("TABLE_NAME", CoerceUnicode, key="table_name"), - Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), - Column("CHECK_OPTION", String, key="check_option"), - Column("IS_UPDATABLE", String, key="is_updatable"), - schema="INFORMATION_SCHEMA") + +schemata = Table( + "SCHEMATA", + ischema, + Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"), + Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"), + Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"), + schema="INFORMATION_SCHEMA", +) + +tables = Table( + "TABLES", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("TABLE_TYPE", String(convert_unicode=True), key="table_type"), + schema="INFORMATION_SCHEMA", +) + +columns = Table( + "COLUMNS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("IS_NULLABLE", Integer, key="is_nullable"), + Column("DATA_TYPE", String, key="data_type"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + Column( + "CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length" + ), + Column("NUMERIC_PRECISION", Integer, key="numeric_precision"), + Column("NUMERIC_SCALE", Integer, key="numeric_scale"), + Column("COLUMN_DEFAULT", Integer, key="column_default"), + Column("COLLATION_NAME", String, key="collation_name"), + schema="INFORMATION_SCHEMA", +) + +constraints = Table( + "TABLE_CONSTRAINTS", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column( + "CONSTRAINT_TYPE", String(convert_unicode=True), key="constraint_type" + ), + schema="INFORMATION_SCHEMA", +) + +column_constraints = Table( + "CONSTRAINT_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + schema="INFORMATION_SCHEMA", +) + +key_constraints = Table( + "KEY_COLUMN_USAGE", + ischema, + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("COLUMN_NAME", CoerceUnicode, key="column_name"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("ORDINAL_POSITION", Integer, key="ordinal_position"), + schema="INFORMATION_SCHEMA", +) + +ref_constraints = Table( + "REFERENTIAL_CONSTRAINTS", + ischema, + Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"), + Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"), + Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"), + # TODO: is CATLOG misspelled ? + Column( + "UNIQUE_CONSTRAINT_CATLOG", + CoerceUnicode, + key="unique_constraint_catalog", + ), + Column( + "UNIQUE_CONSTRAINT_SCHEMA", + CoerceUnicode, + key="unique_constraint_schema", + ), + Column( + "UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name" + ), + Column("MATCH_OPTION", String, key="match_option"), + Column("UPDATE_RULE", String, key="update_rule"), + Column("DELETE_RULE", String, key="delete_rule"), + schema="INFORMATION_SCHEMA", +) + +views = Table( + "VIEWS", + ischema, + Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"), + Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"), + Column("TABLE_NAME", CoerceUnicode, key="table_name"), + Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"), + Column("CHECK_OPTION", String, key="check_option"), + Column("IS_UPDATABLE", String, key="is_updatable"), + schema="INFORMATION_SCHEMA", +) diff --git a/lib/sqlalchemy/dialects/mssql/mxodbc.py b/lib/sqlalchemy/dialects/mssql/mxodbc.py index 8983a3b60..3b9ea2707 100644 --- a/lib/sqlalchemy/dialects/mssql/mxodbc.py +++ b/lib/sqlalchemy/dialects/mssql/mxodbc.py @@ -46,10 +46,14 @@ of ``False`` will unconditionally use string-escaped parameters. from ... import types as sqltypes from ...connectors.mxodbc import MxODBCConnector from .pyodbc import MSExecutionContext_pyodbc, _MSNumeric_pyodbc -from .base import (MSDialect, - MSSQLStrictCompiler, - VARBINARY, - _MSDateTime, _MSDate, _MSTime) +from .base import ( + MSDialect, + MSSQLStrictCompiler, + VARBINARY, + _MSDateTime, + _MSDate, + _MSTime, +) class _MSNumeric_mxodbc(_MSNumeric_pyodbc): @@ -64,6 +68,7 @@ class _MSDate_mxodbc(_MSDate): return "%s-%s-%s" % (value.year, value.month, value.day) else: return None + return process @@ -74,6 +79,7 @@ class _MSTime_mxodbc(_MSTime): return "%s:%s:%s" % (value.hour, value.minute, value.second) else: return None + return process @@ -98,6 +104,7 @@ class _VARBINARY_mxodbc(VARBINARY): else: # should pull from mx.ODBC.Manager.BinaryNull return dialect.dbapi.BinaryNull + return process @@ -107,6 +114,7 @@ class MSExecutionContext_mxodbc(MSExecutionContext_pyodbc): SELECT SCOPE_IDENTITY in cases where OUTPUT clause does not work (tables with insert triggers). """ + # todo - investigate whether the pyodbc execution context # is really only being used in cases where OUTPUT # won't work. @@ -136,4 +144,5 @@ class MSDialect_mxodbc(MxODBCConnector, MSDialect): super(MSDialect_mxodbc, self).__init__(**params) self.description_encoding = description_encoding + dialect = MSDialect_mxodbc diff --git a/lib/sqlalchemy/dialects/mssql/pymssql.py b/lib/sqlalchemy/dialects/mssql/pymssql.py index 8589c8b06..847c00329 100644 --- a/lib/sqlalchemy/dialects/mssql/pymssql.py +++ b/lib/sqlalchemy/dialects/mssql/pymssql.py @@ -35,7 +35,6 @@ class _MSNumeric_pymssql(sqltypes.Numeric): class MSIdentifierPreparer_pymssql(MSIdentifierPreparer): - def __init__(self, dialect): super(MSIdentifierPreparer_pymssql, self).__init__(dialect) # pymssql has the very unusual behavior that it uses pyformat @@ -45,47 +44,45 @@ class MSIdentifierPreparer_pymssql(MSIdentifierPreparer): class MSDialect_pymssql(MSDialect): supports_native_decimal = True - driver = 'pymssql' + driver = "pymssql" preparer = MSIdentifierPreparer_pymssql colspecs = util.update_copy( MSDialect.colspecs, - { - sqltypes.Numeric: _MSNumeric_pymssql, - sqltypes.Float: sqltypes.Float, - } + {sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float}, ) @classmethod def dbapi(cls): - module = __import__('pymssql') + module = __import__("pymssql") # pymmsql < 2.1.1 doesn't have a Binary method. we use string client_ver = tuple(int(x) for x in module.__version__.split(".")) if client_ver < (2, 1, 1): # TODO: monkeypatching here is less than ideal - module.Binary = lambda x: x if hasattr(x, 'decode') else str(x) + module.Binary = lambda x: x if hasattr(x, "decode") else str(x) - if client_ver < (1, ): - util.warn("The pymssql dialect expects at least " - "the 1.0 series of the pymssql DBAPI.") + if client_ver < (1,): + util.warn( + "The pymssql dialect expects at least " + "the 1.0 series of the pymssql DBAPI." + ) return module def _get_server_version_info(self, connection): vers = connection.scalar("select @@version") - m = re.match( - r"Microsoft .*? - (\d+).(\d+).(\d+).(\d+)", vers) + m = re.match(r"Microsoft .*? - (\d+).(\d+).(\d+).(\d+)", vers) if m: return tuple(int(x) for x in m.group(1, 2, 3, 4)) else: return None def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') + opts = url.translate_connect_args(username="user") opts.update(url.query) - port = opts.pop('port', None) - if port and 'host' in opts: - opts['host'] = "%s:%s" % (opts['host'], port) + port = opts.pop("port", None) + if port and "host" in opts: + opts["host"] = "%s:%s" % (opts["host"], port) return [[], opts] def is_disconnect(self, e, connection, cursor): @@ -105,12 +102,13 @@ class MSDialect_pymssql(MSDialect): return False def set_isolation_level(self, connection, level): - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit(True) else: connection.autocommit(False) - super(MSDialect_pymssql, self).set_isolation_level(connection, - level) + super(MSDialect_pymssql, self).set_isolation_level( + connection, level + ) dialect = MSDialect_pymssql diff --git a/lib/sqlalchemy/dialects/mssql/pyodbc.py b/lib/sqlalchemy/dialects/mssql/pyodbc.py index 34f81d6e8..db5573c2c 100644 --- a/lib/sqlalchemy/dialects/mssql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mssql/pyodbc.py @@ -132,15 +132,13 @@ class _ms_numeric_pyodbc(object): def bind_processor(self, dialect): - super_process = super(_ms_numeric_pyodbc, self).\ - bind_processor(dialect) + super_process = super(_ms_numeric_pyodbc, self).bind_processor(dialect) if not dialect._need_decimal_fix: return super_process def process(value): - if self.asdecimal and \ - isinstance(value, decimal.Decimal): + if self.asdecimal and isinstance(value, decimal.Decimal): adjusted = value.adjusted() if adjusted < 0: return self._small_dec_to_string(value) @@ -151,6 +149,7 @@ class _ms_numeric_pyodbc(object): return super_process(value) else: return value + return process # these routines needed for older versions of pyodbc. @@ -158,30 +157,31 @@ class _ms_numeric_pyodbc(object): def _small_dec_to_string(self, value): return "%s0.%s%s" % ( - (value < 0 and '-' or ''), - '0' * (abs(value.adjusted()) - 1), - "".join([str(nint) for nint in value.as_tuple()[1]])) + (value < 0 and "-" or ""), + "0" * (abs(value.adjusted()) - 1), + "".join([str(nint) for nint in value.as_tuple()[1]]), + ) def _large_dec_to_string(self, value): _int = value.as_tuple()[1] - if 'E' in str(value): + if "E" in str(value): result = "%s%s%s" % ( - (value < 0 and '-' or ''), + (value < 0 and "-" or ""), "".join([str(s) for s in _int]), - "0" * (value.adjusted() - (len(_int) - 1))) + "0" * (value.adjusted() - (len(_int) - 1)), + ) else: if (len(_int) - 1) > value.adjusted(): result = "%s%s.%s" % ( - (value < 0 and '-' or ''), - "".join( - [str(s) for s in _int][0:value.adjusted() + 1]), - "".join( - [str(s) for s in _int][value.adjusted() + 1:])) + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + "".join([str(s) for s in _int][value.adjusted() + 1 :]), + ) else: result = "%s%s" % ( - (value < 0 and '-' or ''), - "".join( - [str(s) for s in _int][0:value.adjusted() + 1])) + (value < 0 and "-" or ""), + "".join([str(s) for s in _int][0 : value.adjusted() + 1]), + ) return result @@ -212,6 +212,7 @@ class _ms_binary_pyodbc(object): else: # pyodbc-specific return dialect.dbapi.BinaryNull + return process @@ -243,9 +244,11 @@ class MSExecutionContext_pyodbc(MSExecutionContext): # don't embed the scope_identity select into an # "INSERT .. DEFAULT VALUES" - if self._select_lastrowid and \ - self.dialect.use_scope_identity and \ - len(self.parameters[0]): + if ( + self._select_lastrowid + and self.dialect.use_scope_identity + and len(self.parameters[0]) + ): self._embedded_scope_identity = True self.statement += "; select scope_identity()" @@ -281,26 +284,31 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): sqltypes.Numeric: _MSNumeric_pyodbc, sqltypes.Float: _MSFloat_pyodbc, BINARY: _BINARY_pyodbc, - # SQL Server dialect has a VARBINARY that is just to support # "deprecate_large_types" w/ VARBINARY(max), but also we must # handle the usual SQL standard VARBINARY VARBINARY: _VARBINARY_pyodbc, sqltypes.VARBINARY: _VARBINARY_pyodbc, sqltypes.LargeBinary: _VARBINARY_pyodbc, - } + }, ) - def __init__(self, description_encoding=None, fast_executemany=False, - **params): - if 'description_encoding' in params: - self.description_encoding = params.pop('description_encoding') + def __init__( + self, description_encoding=None, fast_executemany=False, **params + ): + if "description_encoding" in params: + self.description_encoding = params.pop("description_encoding") super(MSDialect_pyodbc, self).__init__(**params) - self.use_scope_identity = self.use_scope_identity and \ - self.dbapi and \ - hasattr(self.dbapi.Cursor, 'nextset') - self._need_decimal_fix = self.dbapi and \ - self._dbapi_version() < (2, 1, 8) + self.use_scope_identity = ( + self.use_scope_identity + and self.dbapi + and hasattr(self.dbapi.Cursor, "nextset") + ) + self._need_decimal_fix = self.dbapi and self._dbapi_version() < ( + 2, + 1, + 8, + ) self.fast_executemany = fast_executemany def _get_server_version_info(self, connection): @@ -308,16 +316,18 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): # "Version of the instance of SQL Server, in the form # of 'major.minor.build.revision'" raw = connection.scalar( - "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)") + "SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)" + ) except exc.DBAPIError: # SQL Server docs indicate this function isn't present prior to # 2008. Before we had the VARCHAR cast above, pyodbc would also # fail on this query. - return super(MSDialect_pyodbc, self).\ - _get_server_version_info(connection, allow_chars=False) + return super(MSDialect_pyodbc, self)._get_server_version_info( + connection, allow_chars=False + ) else: version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") for n in r.split(raw): try: version.append(int(n)) @@ -329,17 +339,27 @@ class MSDialect_pyodbc(PyODBCConnector, MSDialect): if self.fast_executemany: cursor.fast_executemany = True super(MSDialect_pyodbc, self).do_executemany( - cursor, statement, parameters, context=context) + cursor, statement, parameters, context=context + ) def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.Error): for code in ( - '08S01', '01002', '08003', '08007', - '08S02', '08001', 'HYT00', 'HY010', - '10054'): + "08S01", + "01002", + "08003", + "08007", + "08S02", + "08001", + "HYT00", + "HY010", + "10054", + ): if code in str(e): return True return super(MSDialect_pyodbc, self).is_disconnect( - e, connection, cursor) + e, connection, cursor + ) + dialect = MSDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mssql/zxjdbc.py b/lib/sqlalchemy/dialects/mssql/zxjdbc.py index 3fb93b28a..13fc46e19 100644 --- a/lib/sqlalchemy/dialects/mssql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/mssql/zxjdbc.py @@ -44,26 +44,28 @@ class MSExecutionContext_zxjdbc(MSExecutionContext): self.cursor.nextset() self._lastrowid = int(row[0]) - if (self.isinsert or self.isupdate or self.isdelete) and \ - self.compiled.returning: + if ( + self.isinsert or self.isupdate or self.isdelete + ) and self.compiled.returning: self._result_proxy = engine.FullyBufferedResultProxy(self) if self._enable_identity_insert: table = self.dialect.identifier_preparer.format_table( - self.compiled.statement.table) + self.compiled.statement.table + ) self.cursor.execute("SET IDENTITY_INSERT %s OFF" % table) class MSDialect_zxjdbc(ZxJDBCConnector, MSDialect): - jdbc_db_name = 'jtds:sqlserver' - jdbc_driver_name = 'net.sourceforge.jtds.jdbc.Driver' + jdbc_db_name = "jtds:sqlserver" + jdbc_driver_name = "net.sourceforge.jtds.jdbc.Driver" execution_ctx_cls = MSExecutionContext_zxjdbc def _get_server_version_info(self, connection): return tuple( - int(x) - for x in connection.connection.dbversion.split('.') + int(x) for x in connection.connection.dbversion.split(".") ) + dialect = MSDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/mysql/__init__.py b/lib/sqlalchemy/dialects/mysql/__init__.py index de4e1fa41..ffeb8f486 100644 --- a/lib/sqlalchemy/dialects/mysql/__init__.py +++ b/lib/sqlalchemy/dialects/mysql/__init__.py @@ -5,18 +5,56 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from . import base, mysqldb, oursql, \ - pyodbc, zxjdbc, mysqlconnector, pymysql, \ - gaerdbms, cymysql +from . import ( + base, + mysqldb, + oursql, + pyodbc, + zxjdbc, + mysqlconnector, + pymysql, + gaerdbms, + cymysql, +) -from .base import \ - BIGINT, BINARY, BIT, BLOB, BOOLEAN, CHAR, DATE, DATETIME, \ - DECIMAL, DOUBLE, ENUM, DECIMAL,\ - FLOAT, INTEGER, INTEGER, JSON, LONGBLOB, LONGTEXT, MEDIUMBLOB, \ - MEDIUMINT, MEDIUMTEXT, NCHAR, \ - NVARCHAR, NUMERIC, SET, SMALLINT, REAL, TEXT, TIME, TIMESTAMP, \ - TINYBLOB, TINYINT, TINYTEXT,\ - VARBINARY, VARCHAR, YEAR +from .base import ( + BIGINT, + BINARY, + BIT, + BLOB, + BOOLEAN, + CHAR, + DATE, + DATETIME, + DECIMAL, + DOUBLE, + ENUM, + DECIMAL, + FLOAT, + INTEGER, + INTEGER, + JSON, + LONGBLOB, + LONGTEXT, + MEDIUMBLOB, + MEDIUMINT, + MEDIUMTEXT, + NCHAR, + NVARCHAR, + NUMERIC, + SET, + SMALLINT, + REAL, + TEXT, + TIME, + TIMESTAMP, + TINYBLOB, + TINYINT, + TINYTEXT, + VARBINARY, + VARCHAR, + YEAR, +) from .dml import insert, Insert @@ -25,10 +63,41 @@ base.dialect = dialect = mysqldb.dialect __all__ = ( - 'BIGINT', 'BINARY', 'BIT', 'BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', - 'DECIMAL', 'DOUBLE', 'ENUM', 'DECIMAL', 'FLOAT', 'INTEGER', 'INTEGER', - 'JSON', 'LONGBLOB', 'LONGTEXT', 'MEDIUMBLOB', 'MEDIUMINT', 'MEDIUMTEXT', - 'NCHAR', 'NVARCHAR', 'NUMERIC', 'SET', 'SMALLINT', 'REAL', 'TEXT', 'TIME', - 'TIMESTAMP', 'TINYBLOB', 'TINYINT', 'TINYTEXT', 'VARBINARY', 'VARCHAR', - 'YEAR', 'dialect' + "BIGINT", + "BINARY", + "BIT", + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "DOUBLE", + "ENUM", + "DECIMAL", + "FLOAT", + "INTEGER", + "INTEGER", + "JSON", + "LONGBLOB", + "LONGTEXT", + "MEDIUMBLOB", + "MEDIUMINT", + "MEDIUMTEXT", + "NCHAR", + "NVARCHAR", + "NUMERIC", + "SET", + "SMALLINT", + "REAL", + "TEXT", + "TIME", + "TIMESTAMP", + "TINYBLOB", + "TINYINT", + "TINYTEXT", + "VARBINARY", + "VARCHAR", + "YEAR", + "dialect", ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 673d4b9ff..7b0d0618c 100644 --- a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -746,85 +746,340 @@ from ...engine import reflection from ...engine import default from ... import types as sqltypes from ...util import topological -from ...types import DATE, BOOLEAN, \ - BLOB, BINARY, VARBINARY +from ...types import DATE, BOOLEAN, BLOB, BINARY, VARBINARY from . import reflection as _reflection -from .types import BIGINT, BIT, CHAR, DECIMAL, DATETIME, \ - DOUBLE, FLOAT, INTEGER, LONGBLOB, LONGTEXT, MEDIUMBLOB, MEDIUMINT, \ - MEDIUMTEXT, NCHAR, NUMERIC, NVARCHAR, REAL, SMALLINT, TEXT, TIME, \ - TIMESTAMP, TINYBLOB, TINYINT, TINYTEXT, VARCHAR, YEAR -from .types import _StringType, _IntegerType, _NumericType, \ - _FloatType, _MatchType +from .types import ( + BIGINT, + BIT, + CHAR, + DECIMAL, + DATETIME, + DOUBLE, + FLOAT, + INTEGER, + LONGBLOB, + LONGTEXT, + MEDIUMBLOB, + MEDIUMINT, + MEDIUMTEXT, + NCHAR, + NUMERIC, + NVARCHAR, + REAL, + SMALLINT, + TEXT, + TIME, + TIMESTAMP, + TINYBLOB, + TINYINT, + TINYTEXT, + VARCHAR, + YEAR, +) +from .types import ( + _StringType, + _IntegerType, + _NumericType, + _FloatType, + _MatchType, +) from .enumerated import ENUM, SET from .json import JSON, JSONIndexType, JSONPathType RESERVED_WORDS = set( - ['accessible', 'add', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'asensitive', 'before', 'between', 'bigint', 'binary', 'blob', 'both', - 'by', 'call', 'cascade', 'case', 'change', 'char', 'character', 'check', - 'collate', 'column', 'condition', 'constraint', 'continue', 'convert', - 'create', 'cross', 'current_date', 'current_time', 'current_timestamp', - 'current_user', 'cursor', 'database', 'databases', 'day_hour', - 'day_microsecond', 'day_minute', 'day_second', 'dec', 'decimal', - 'declare', 'default', 'delayed', 'delete', 'desc', 'describe', - 'deterministic', 'distinct', 'distinctrow', 'div', 'double', 'drop', - 'dual', 'each', 'else', 'elseif', 'enclosed', 'escaped', 'exists', - 'exit', 'explain', 'false', 'fetch', 'float', 'float4', 'float8', - 'for', 'force', 'foreign', 'from', 'fulltext', 'grant', 'group', - 'having', 'high_priority', 'hour_microsecond', 'hour_minute', - 'hour_second', 'if', 'ignore', 'in', 'index', 'infile', 'inner', 'inout', - 'insensitive', 'insert', 'int', 'int1', 'int2', 'int3', 'int4', 'int8', - 'integer', 'interval', 'into', 'is', 'iterate', 'join', 'key', 'keys', - 'kill', 'leading', 'leave', 'left', 'like', 'limit', 'linear', 'lines', - 'load', 'localtime', 'localtimestamp', 'lock', 'long', 'longblob', - 'longtext', 'loop', 'low_priority', 'master_ssl_verify_server_cert', - 'match', 'mediumblob', 'mediumint', 'mediumtext', 'middleint', - 'minute_microsecond', 'minute_second', 'mod', 'modifies', 'natural', - 'not', 'no_write_to_binlog', 'null', 'numeric', 'on', 'optimize', - 'option', 'optionally', 'or', 'order', 'out', 'outer', 'outfile', - 'precision', 'primary', 'procedure', 'purge', 'range', 'read', 'reads', - 'read_only', 'read_write', 'real', 'references', 'regexp', 'release', - 'rename', 'repeat', 'replace', 'require', 'restrict', 'return', - 'revoke', 'right', 'rlike', 'schema', 'schemas', 'second_microsecond', - 'select', 'sensitive', 'separator', 'set', 'show', 'smallint', 'spatial', - 'specific', 'sql', 'sqlexception', 'sqlstate', 'sqlwarning', - 'sql_big_result', 'sql_calc_found_rows', 'sql_small_result', 'ssl', - 'starting', 'straight_join', 'table', 'terminated', 'then', 'tinyblob', - 'tinyint', 'tinytext', 'to', 'trailing', 'trigger', 'true', 'undo', - 'union', 'unique', 'unlock', 'unsigned', 'update', 'usage', 'use', - 'using', 'utc_date', 'utc_time', 'utc_timestamp', 'values', 'varbinary', - 'varchar', 'varcharacter', 'varying', 'when', 'where', 'while', 'with', - - 'write', 'x509', 'xor', 'year_month', 'zerofill', # 5.0 - - 'columns', 'fields', 'privileges', 'soname', 'tables', # 4.1 - - 'accessible', 'linear', 'master_ssl_verify_server_cert', 'range', - 'read_only', 'read_write', # 5.1 - - 'general', 'ignore_server_ids', 'master_heartbeat_period', 'maxvalue', - 'resignal', 'signal', 'slow', # 5.5 - - 'get', 'io_after_gtids', 'io_before_gtids', 'master_bind', 'one_shot', - 'partition', 'sql_after_gtids', 'sql_before_gtids', # 5.6 - - 'generated', 'optimizer_costs', 'stored', 'virtual', # 5.7 - - 'admin', 'cume_dist', 'empty', 'except', 'first_value', 'grouping', - 'function', 'groups', 'json_table', 'last_value', 'nth_value', - 'ntile', 'of', 'over', 'percent_rank', 'persist', 'persist_only', - 'rank', 'recursive', 'role', 'row', 'rows', 'row_number', 'system', - 'window', # 8.0 - ]) + [ + "accessible", + "add", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "asensitive", + "before", + "between", + "bigint", + "binary", + "blob", + "both", + "by", + "call", + "cascade", + "case", + "change", + "char", + "character", + "check", + "collate", + "column", + "condition", + "constraint", + "continue", + "convert", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "current_user", + "cursor", + "database", + "databases", + "day_hour", + "day_microsecond", + "day_minute", + "day_second", + "dec", + "decimal", + "declare", + "default", + "delayed", + "delete", + "desc", + "describe", + "deterministic", + "distinct", + "distinctrow", + "div", + "double", + "drop", + "dual", + "each", + "else", + "elseif", + "enclosed", + "escaped", + "exists", + "exit", + "explain", + "false", + "fetch", + "float", + "float4", + "float8", + "for", + "force", + "foreign", + "from", + "fulltext", + "grant", + "group", + "having", + "high_priority", + "hour_microsecond", + "hour_minute", + "hour_second", + "if", + "ignore", + "in", + "index", + "infile", + "inner", + "inout", + "insensitive", + "insert", + "int", + "int1", + "int2", + "int3", + "int4", + "int8", + "integer", + "interval", + "into", + "is", + "iterate", + "join", + "key", + "keys", + "kill", + "leading", + "leave", + "left", + "like", + "limit", + "linear", + "lines", + "load", + "localtime", + "localtimestamp", + "lock", + "long", + "longblob", + "longtext", + "loop", + "low_priority", + "master_ssl_verify_server_cert", + "match", + "mediumblob", + "mediumint", + "mediumtext", + "middleint", + "minute_microsecond", + "minute_second", + "mod", + "modifies", + "natural", + "not", + "no_write_to_binlog", + "null", + "numeric", + "on", + "optimize", + "option", + "optionally", + "or", + "order", + "out", + "outer", + "outfile", + "precision", + "primary", + "procedure", + "purge", + "range", + "read", + "reads", + "read_only", + "read_write", + "real", + "references", + "regexp", + "release", + "rename", + "repeat", + "replace", + "require", + "restrict", + "return", + "revoke", + "right", + "rlike", + "schema", + "schemas", + "second_microsecond", + "select", + "sensitive", + "separator", + "set", + "show", + "smallint", + "spatial", + "specific", + "sql", + "sqlexception", + "sqlstate", + "sqlwarning", + "sql_big_result", + "sql_calc_found_rows", + "sql_small_result", + "ssl", + "starting", + "straight_join", + "table", + "terminated", + "then", + "tinyblob", + "tinyint", + "tinytext", + "to", + "trailing", + "trigger", + "true", + "undo", + "union", + "unique", + "unlock", + "unsigned", + "update", + "usage", + "use", + "using", + "utc_date", + "utc_time", + "utc_timestamp", + "values", + "varbinary", + "varchar", + "varcharacter", + "varying", + "when", + "where", + "while", + "with", + "write", + "x509", + "xor", + "year_month", + "zerofill", # 5.0 + "columns", + "fields", + "privileges", + "soname", + "tables", # 4.1 + "accessible", + "linear", + "master_ssl_verify_server_cert", + "range", + "read_only", + "read_write", # 5.1 + "general", + "ignore_server_ids", + "master_heartbeat_period", + "maxvalue", + "resignal", + "signal", + "slow", # 5.5 + "get", + "io_after_gtids", + "io_before_gtids", + "master_bind", + "one_shot", + "partition", + "sql_after_gtids", + "sql_before_gtids", # 5.6 + "generated", + "optimizer_costs", + "stored", + "virtual", # 5.7 + "admin", + "cume_dist", + "empty", + "except", + "first_value", + "grouping", + "function", + "groups", + "json_table", + "last_value", + "nth_value", + "ntile", + "of", + "over", + "percent_rank", + "persist", + "persist_only", + "rank", + "recursive", + "role", + "row", + "rows", + "row_number", + "system", + "window", # 8.0 + ] +) AUTOCOMMIT_RE = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)', - re.I | re.UNICODE) + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|LOAD +DATA|REPLACE)", + re.I | re.UNICODE, +) SET_RE = re.compile( - r'\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w', - re.I | re.UNICODE) + r"\s*SET\s+(?:(?:GLOBAL|SESSION)\s+)?\w", re.I | re.UNICODE +) # old names @@ -870,52 +1125,50 @@ colspecs = { sqltypes.MatchType: _MatchType, sqltypes.JSON: JSON, sqltypes.JSON.JSONIndexType: JSONIndexType, - sqltypes.JSON.JSONPathType: JSONPathType - + sqltypes.JSON.JSONPathType: JSONPathType, } # Everything 3.23 through 5.1 excepting OpenGIS types. ischema_names = { - 'bigint': BIGINT, - 'binary': BINARY, - 'bit': BIT, - 'blob': BLOB, - 'boolean': BOOLEAN, - 'char': CHAR, - 'date': DATE, - 'datetime': DATETIME, - 'decimal': DECIMAL, - 'double': DOUBLE, - 'enum': ENUM, - 'fixed': DECIMAL, - 'float': FLOAT, - 'int': INTEGER, - 'integer': INTEGER, - 'json': JSON, - 'longblob': LONGBLOB, - 'longtext': LONGTEXT, - 'mediumblob': MEDIUMBLOB, - 'mediumint': MEDIUMINT, - 'mediumtext': MEDIUMTEXT, - 'nchar': NCHAR, - 'nvarchar': NVARCHAR, - 'numeric': NUMERIC, - 'set': SET, - 'smallint': SMALLINT, - 'text': TEXT, - 'time': TIME, - 'timestamp': TIMESTAMP, - 'tinyblob': TINYBLOB, - 'tinyint': TINYINT, - 'tinytext': TINYTEXT, - 'varbinary': VARBINARY, - 'varchar': VARCHAR, - 'year': YEAR, + "bigint": BIGINT, + "binary": BINARY, + "bit": BIT, + "blob": BLOB, + "boolean": BOOLEAN, + "char": CHAR, + "date": DATE, + "datetime": DATETIME, + "decimal": DECIMAL, + "double": DOUBLE, + "enum": ENUM, + "fixed": DECIMAL, + "float": FLOAT, + "int": INTEGER, + "integer": INTEGER, + "json": JSON, + "longblob": LONGBLOB, + "longtext": LONGTEXT, + "mediumblob": MEDIUMBLOB, + "mediumint": MEDIUMINT, + "mediumtext": MEDIUMTEXT, + "nchar": NCHAR, + "nvarchar": NVARCHAR, + "numeric": NUMERIC, + "set": SET, + "smallint": SMALLINT, + "text": TEXT, + "time": TIME, + "timestamp": TIMESTAMP, + "tinyblob": TINYBLOB, + "tinyint": TINYINT, + "tinytext": TINYTEXT, + "varbinary": VARBINARY, + "varchar": VARCHAR, + "year": YEAR, } class MySQLExecutionContext(default.DefaultExecutionContext): - def should_autocommit_text(self, statement): return AUTOCOMMIT_RE.match(statement) @@ -932,7 +1185,7 @@ class MySQLCompiler(compiler.SQLCompiler): """Overridden from base SQLCompiler value""" extract_map = compiler.SQLCompiler.extract_map.copy() - extract_map.update({'milliseconds': 'millisecond'}) + extract_map.update({"milliseconds": "millisecond"}) def visit_random_func(self, fn, **kw): return "rand%s" % self.function_argspec(fn) @@ -943,12 +1196,14 @@ class MySQLCompiler(compiler.SQLCompiler): def visit_json_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return "JSON_EXTRACT(%s, %s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_on_duplicate_key_update(self, on_duplicate, **kw): if on_duplicate._parameter_ordering: @@ -958,7 +1213,8 @@ class MySQLCompiler(compiler.SQLCompiler): ] ordered_keys = set(parameter_ordering) cols = [ - self.statement.table.c[key] for key in parameter_ordering + self.statement.table.c[key] + for key in parameter_ordering if key in self.statement.table.c ] + [ c for c in self.statement.table.c if c.key not in ordered_keys @@ -979,9 +1235,11 @@ class MySQLCompiler(compiler.SQLCompiler): val = val._clone() val.type = column.type value_text = self.process(val.self_group(), use_schema=False) - elif isinstance(val, elements.ColumnClause) \ - and val.table is on_duplicate.inserted_alias: - value_text = 'VALUES(' + self.preparer.quote(column.name) + ')' + elif ( + isinstance(val, elements.ColumnClause) + and val.table is on_duplicate.inserted_alias + ): + value_text = "VALUES(" + self.preparer.quote(column.name) + ")" else: value_text = self.process(val.self_group(), use_schema=False) name_text = self.preparer.quote(column.name) @@ -990,22 +1248,27 @@ class MySQLCompiler(compiler.SQLCompiler): non_matching = set(on_duplicate.update) - set(c.key for c in cols) if non_matching: util.warn( - 'Additional column names not matching ' - "any column keys in table '%s': %s" % ( + "Additional column names not matching " + "any column keys in table '%s': %s" + % ( self.statement.table.name, - (', '.join("'%s'" % c for c in non_matching)) + (", ".join("'%s'" % c for c in non_matching)), ) ) - return 'ON DUPLICATE KEY UPDATE ' + ', '.join(clauses) + return "ON DUPLICATE KEY UPDATE " + ", ".join(clauses) def visit_concat_op_binary(self, binary, operator, **kw): - return "concat(%s, %s)" % (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "concat(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_match_op_binary(self, binary, operator, **kw): - return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % \ - (self.process(binary.left, **kw), self.process(binary.right, **kw)) + return "MATCH (%s) AGAINST (%s IN BOOLEAN MODE)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def get_from_hint_text(self, table, text): return text @@ -1016,26 +1279,35 @@ class MySQLCompiler(compiler.SQLCompiler): if isinstance(type_, sqltypes.TypeDecorator): return self.visit_typeclause(typeclause, type_.impl, **kw) elif isinstance(type_, sqltypes.Integer): - if getattr(type_, 'unsigned', False): - return 'UNSIGNED INTEGER' + if getattr(type_, "unsigned", False): + return "UNSIGNED INTEGER" else: - return 'SIGNED INTEGER' + return "SIGNED INTEGER" elif isinstance(type_, sqltypes.TIMESTAMP): - return 'DATETIME' - elif isinstance(type_, (sqltypes.DECIMAL, sqltypes.DateTime, - sqltypes.Date, sqltypes.Time)): + return "DATETIME" + elif isinstance( + type_, + ( + sqltypes.DECIMAL, + sqltypes.DateTime, + sqltypes.Date, + sqltypes.Time, + ), + ): return self.dialect.type_compiler.process(type_) - elif isinstance(type_, sqltypes.String) \ - and not isinstance(type_, (ENUM, SET)): + elif isinstance(type_, sqltypes.String) and not isinstance( + type_, (ENUM, SET) + ): adapted = CHAR._adapt_string_for_cast(type_) return self.dialect.type_compiler.process(adapted) elif isinstance(type_, sqltypes._Binary): - return 'BINARY' + return "BINARY" elif isinstance(type_, sqltypes.JSON): return "JSON" elif isinstance(type_, sqltypes.NUMERIC): - return self.dialect.type_compiler.process( - type_).replace('NUMERIC', 'DECIMAL') + return self.dialect.type_compiler.process(type_).replace( + "NUMERIC", "DECIMAL" + ) else: return None @@ -1044,23 +1316,25 @@ class MySQLCompiler(compiler.SQLCompiler): if not self.dialect._supports_cast: util.warn( "Current MySQL version does not support " - "CAST; the CAST will be skipped.") + "CAST; the CAST will be skipped." + ) return self.process(cast.clause.self_group(), **kw) type_ = self.process(cast.typeclause) if type_ is None: util.warn( "Datatype %s does not support CAST on MySQL; " - "the CAST will be skipped." % - self.dialect.type_compiler.process(cast.typeclause.type)) + "the CAST will be skipped." + % self.dialect.type_compiler.process(cast.typeclause.type) + ) return self.process(cast.clause.self_group(), **kw) - return 'CAST(%s AS %s)' % (self.process(cast.clause, **kw), type_) + return "CAST(%s AS %s)" % (self.process(cast.clause, **kw), type_) def render_literal_value(self, value, type_): value = super(MySQLCompiler, self).render_literal_value(value, type_) if self.dialect._backslash_escapes: - value = value.replace('\\', '\\\\') + value = value.replace("\\", "\\\\") return value # override native_boolean=False behavior here, as @@ -1096,12 +1370,15 @@ class MySQLCompiler(compiler.SQLCompiler): else: join_type = " INNER JOIN " - return ''.join( - (self.process(join.left, asfrom=True, **kwargs), - join_type, - self.process(join.right, asfrom=True, **kwargs), - " ON ", - self.process(join.onclause, **kwargs))) + return "".join( + ( + self.process(join.left, asfrom=True, **kwargs), + join_type, + self.process(join.right, asfrom=True, **kwargs), + " ON ", + self.process(join.onclause, **kwargs), + ) + ) def for_update_clause(self, select, **kw): if select._for_update_arg.read: @@ -1118,11 +1395,13 @@ class MySQLCompiler(compiler.SQLCompiler): # The latter is more readable for offsets but we're stuck with the # former until we can refine dialects by server revision. - limit_clause, offset_clause = select._limit_clause, \ - select._offset_clause + limit_clause, offset_clause = ( + select._limit_clause, + select._offset_clause, + ) if limit_clause is None and offset_clause is None: - return '' + return "" elif offset_clause is not None: # As suggested by the MySQL docs, need to apply an # artificial limit if one wasn't provided @@ -1134,35 +1413,38 @@ class MySQLCompiler(compiler.SQLCompiler): # but also is consistent with the usage of the upper # bound as part of MySQL's "syntax" for OFFSET with # no LIMIT - return ' \n LIMIT %s, %s' % ( + return " \n LIMIT %s, %s" % ( self.process(offset_clause, **kw), - "18446744073709551615") + "18446744073709551615", + ) else: - return ' \n LIMIT %s, %s' % ( + return " \n LIMIT %s, %s" % ( self.process(offset_clause, **kw), - self.process(limit_clause, **kw)) + self.process(limit_clause, **kw), + ) else: # No offset provided, so just use the limit - return ' \n LIMIT %s' % (self.process(limit_clause, **kw),) + return " \n LIMIT %s" % (self.process(limit_clause, **kw),) def update_limit_clause(self, update_stmt): - limit = update_stmt.kwargs.get('%s_limit' % self.dialect.name, None) + limit = update_stmt.kwargs.get("%s_limit" % self.dialect.name, None) if limit: return "LIMIT %s" % limit else: return None - def update_tables_clause(self, update_stmt, from_table, - extra_froms, **kw): - return ', '.join(t._compiler_dispatch(self, asfrom=True, **kw) - for t in [from_table] + list(extra_froms)) + def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw): + return ", ".join( + t._compiler_dispatch(self, asfrom=True, **kw) + for t in [from_table] + list(extra_froms) + ) - def update_from_clause(self, update_stmt, from_table, - extra_froms, from_hints, **kw): + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): return None - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -1171,24 +1453,27 @@ class MySQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. USING clause specific to MySQL.""" - return "USING " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "USING " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) def visit_empty_set_expr(self, element_types): return ( "SELECT %(outer)s FROM (SELECT %(inner)s) " - "as _empty_set WHERE 1!=1" % { + "as _empty_set WHERE 1!=1" + % { "inner": ", ".join( "1 AS _in_%s" % idx - for idx, type_ in enumerate(element_types)), + for idx, type_ in enumerate(element_types) + ), "outer": ", ".join( - "_in_%s" % idx - for idx, type_ in enumerate(element_types)) + "_in_%s" % idx for idx, type_ in enumerate(element_types) + ), } ) @@ -1200,35 +1485,39 @@ class MySQLDDLCompiler(compiler.DDLCompiler): colspec = [ self.preparer.format_column(column), self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ), ] is_timestamp = isinstance(column.type, sqltypes.TIMESTAMP) if not column.nullable: - colspec.append('NOT NULL') + colspec.append("NOT NULL") # see: http://docs.sqlalchemy.org/en/latest/dialects/ # mysql.html#mysql_timestamp_null elif column.nullable and is_timestamp: - colspec.append('NULL') + colspec.append("NULL") default = self.get_column_default_string(column) if default is not None: - colspec.append('DEFAULT ' + default) + colspec.append("DEFAULT " + default) comment = column.comment if comment is not None: literal = self.sql_compiler.render_literal_value( - comment, sqltypes.String()) - colspec.append('COMMENT ' + literal) + comment, sqltypes.String() + ) + colspec.append("COMMENT " + literal) - if column.table is not None \ - and column is column.table._autoincrement_column and \ - column.server_default is None: - colspec.append('AUTO_INCREMENT') + if ( + column.table is not None + and column is column.table._autoincrement_column + and column.server_default is None + ): + colspec.append("AUTO_INCREMENT") - return ' '.join(colspec) + return " ".join(colspec) def post_create_table(self, table): """Build table-level CREATE options like ENGINE and COLLATE.""" @@ -1236,76 +1525,94 @@ class MySQLDDLCompiler(compiler.DDLCompiler): table_opts = [] opts = dict( - ( - k[len(self.dialect.name) + 1:].upper(), - v - ) + (k[len(self.dialect.name) + 1 :].upper(), v) for k, v in table.kwargs.items() - if k.startswith('%s_' % self.dialect.name) + if k.startswith("%s_" % self.dialect.name) ) if table.comment is not None: - opts['COMMENT'] = table.comment + opts["COMMENT"] = table.comment partition_options = [ - 'PARTITION_BY', 'PARTITIONS', 'SUBPARTITIONS', - 'SUBPARTITION_BY' + "PARTITION_BY", + "PARTITIONS", + "SUBPARTITIONS", + "SUBPARTITION_BY", ] nonpart_options = set(opts).difference(partition_options) part_options = set(opts).intersection(partition_options) - for opt in topological.sort([ - ('DEFAULT_CHARSET', 'COLLATE'), - ('DEFAULT_CHARACTER_SET', 'COLLATE'), - ], nonpart_options): + for opt in topological.sort( + [ + ("DEFAULT_CHARSET", "COLLATE"), + ("DEFAULT_CHARACTER_SET", "COLLATE"), + ], + nonpart_options, + ): arg = opts[opt] if opt in _reflection._options_of_type_string: arg = self.sql_compiler.render_literal_value( - arg, sqltypes.String()) - - if opt in ('DATA_DIRECTORY', 'INDEX_DIRECTORY', - 'DEFAULT_CHARACTER_SET', 'CHARACTER_SET', - 'DEFAULT_CHARSET', - 'DEFAULT_COLLATE'): - opt = opt.replace('_', ' ') + arg, sqltypes.String() + ) - joiner = '=' - if opt in ('TABLESPACE', 'DEFAULT CHARACTER SET', - 'CHARACTER SET', 'COLLATE'): - joiner = ' ' + if opt in ( + "DATA_DIRECTORY", + "INDEX_DIRECTORY", + "DEFAULT_CHARACTER_SET", + "CHARACTER_SET", + "DEFAULT_CHARSET", + "DEFAULT_COLLATE", + ): + opt = opt.replace("_", " ") + + joiner = "=" + if opt in ( + "TABLESPACE", + "DEFAULT CHARACTER SET", + "CHARACTER SET", + "COLLATE", + ): + joiner = " " table_opts.append(joiner.join((opt, arg))) - for opt in topological.sort([ - ('PARTITION_BY', 'PARTITIONS'), - ('PARTITION_BY', 'SUBPARTITION_BY'), - ('PARTITION_BY', 'SUBPARTITIONS'), - ('PARTITIONS', 'SUBPARTITIONS'), - ('PARTITIONS', 'SUBPARTITION_BY'), - ('SUBPARTITION_BY', 'SUBPARTITIONS') - ], part_options): + for opt in topological.sort( + [ + ("PARTITION_BY", "PARTITIONS"), + ("PARTITION_BY", "SUBPARTITION_BY"), + ("PARTITION_BY", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITIONS"), + ("PARTITIONS", "SUBPARTITION_BY"), + ("SUBPARTITION_BY", "SUBPARTITIONS"), + ], + part_options, + ): arg = opts[opt] if opt in _reflection._options_of_type_string: arg = self.sql_compiler.render_literal_value( - arg, sqltypes.String()) + arg, sqltypes.String() + ) - opt = opt.replace('_', ' ') - joiner = ' ' + opt = opt.replace("_", " ") + joiner = " " table_opts.append(joiner.join((opt, arg))) - return ' '.join(table_opts) + return " ".join(table_opts) def visit_create_index(self, create, **kw): index = create.element self._verify_index_table(index) preparer = self.preparer table = preparer.format_table(index.table) - columns = [self.sql_compiler.process(expr, include_table=False, - literal_binds=True) - for expr in index.expressions] + columns = [ + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ] name = self._prepared_index_name(index) @@ -1313,53 +1620,54 @@ class MySQLDDLCompiler(compiler.DDLCompiler): if index.unique: text += "UNIQUE " - index_prefix = index.kwargs.get('mysql_prefix', None) + index_prefix = index.kwargs.get("mysql_prefix", None) if index_prefix: - text += index_prefix + ' ' + text += index_prefix + " " text += "INDEX %s ON %s " % (name, table) - length = index.dialect_options['mysql']['length'] + length = index.dialect_options["mysql"]["length"] if length is not None: if isinstance(length, dict): # length value can be a (column_name --> integer value) # mapping specifying the prefix length for each column of the # index - columns = ', '.join( - '%s(%d)' % (expr, length[col.name]) if col.name in length - else - ( - '%s(%d)' % (expr, length[expr]) if expr in length - else '%s' % expr + columns = ", ".join( + "%s(%d)" % (expr, length[col.name]) + if col.name in length + else ( + "%s(%d)" % (expr, length[expr]) + if expr in length + else "%s" % expr ) for col, expr in zip(index.expressions, columns) ) else: # or can be an integer value specifying the same # prefix length for all columns of the index - columns = ', '.join( - '%s(%d)' % (col, length) - for col in columns + columns = ", ".join( + "%s(%d)" % (col, length) for col in columns ) else: - columns = ', '.join(columns) - text += '(%s)' % columns + columns = ", ".join(columns) + text += "(%s)" % columns - parser = index.dialect_options['mysql']['with_parser'] + parser = index.dialect_options["mysql"]["with_parser"] if parser is not None: - text += " WITH PARSER %s" % (parser, ) + text += " WITH PARSER %s" % (parser,) - using = index.dialect_options['mysql']['using'] + using = index.dialect_options["mysql"]["using"] if using is not None: text += " USING %s" % (preparer.quote(using)) return text def visit_primary_key_constraint(self, constraint): - text = super(MySQLDDLCompiler, self).\ - visit_primary_key_constraint(constraint) - using = constraint.dialect_options['mysql']['using'] + text = super(MySQLDDLCompiler, self).visit_primary_key_constraint( + constraint + ) + using = constraint.dialect_options["mysql"]["using"] if using: text += " USING %s" % (self.preparer.quote(using)) return text @@ -1368,9 +1676,9 @@ class MySQLDDLCompiler(compiler.DDLCompiler): index = drop.element return "\nDROP INDEX %s ON %s" % ( - self._prepared_index_name(index, - include_schema=False), - self.preparer.format_table(index.table)) + self._prepared_index_name(index, include_schema=False), + self.preparer.format_table(index.table), + ) def visit_drop_constraint(self, drop): constraint = drop.element @@ -1386,29 +1694,33 @@ class MySQLDDLCompiler(compiler.DDLCompiler): else: qual = "" const = self.preparer.format_constraint(constraint) - return "ALTER TABLE %s DROP %s%s" % \ - (self.preparer.format_table(constraint.table), - qual, const) + return "ALTER TABLE %s DROP %s%s" % ( + self.preparer.format_table(constraint.table), + qual, + const, + ) def define_constraint_match(self, constraint): if constraint.match is not None: raise exc.CompileError( "MySQL ignores the 'MATCH' keyword while at the same time " - "causes ON UPDATE/ON DELETE clauses to be ignored.") + "causes ON UPDATE/ON DELETE clauses to be ignored." + ) return "" def visit_set_table_comment(self, create): return "ALTER TABLE %s COMMENT %s" % ( self.preparer.format_table(create.element), self.sql_compiler.render_literal_value( - create.element.comment, sqltypes.String()) + create.element.comment, sqltypes.String() + ), ) def visit_set_column_comment(self, create): return "ALTER TABLE %s CHANGE %s %s" % ( self.preparer.format_table(create.element.table), self.preparer.format_column(create.element), - self.get_column_specification(create.element) + self.get_column_specification(create.element), ) @@ -1420,9 +1732,9 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return spec if type_.unsigned: - spec += ' UNSIGNED' + spec += " UNSIGNED" if type_.zerofill: - spec += ' ZEROFILL' + spec += " ZEROFILL" return spec def _extend_string(self, type_, defaults, spec): @@ -1434,28 +1746,30 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def attr(name): return getattr(type_, name, defaults.get(name)) - if attr('charset'): - charset = 'CHARACTER SET %s' % attr('charset') - elif attr('ascii'): - charset = 'ASCII' - elif attr('unicode'): - charset = 'UNICODE' + if attr("charset"): + charset = "CHARACTER SET %s" % attr("charset") + elif attr("ascii"): + charset = "ASCII" + elif attr("unicode"): + charset = "UNICODE" else: charset = None - if attr('collation'): - collation = 'COLLATE %s' % type_.collation - elif attr('binary'): - collation = 'BINARY' + if attr("collation"): + collation = "COLLATE %s" % type_.collation + elif attr("binary"): + collation = "BINARY" else: collation = None - if attr('national'): + if attr("national"): # NATIONAL (aka NCHAR/NVARCHAR) trumps charsets. - return ' '.join([c for c in ('NATIONAL', spec, collation) - if c is not None]) - return ' '.join([c for c in (spec, charset, collation) - if c is not None]) + return " ".join( + [c for c in ("NATIONAL", spec, collation) if c is not None] + ) + return " ".join( + [c for c in (spec, charset, collation) if c is not None] + ) def _mysql_type(self, type_): return isinstance(type_, (_StringType, _NumericType)) @@ -1464,95 +1778,113 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): if type_.precision is None: return self._extend_numeric(type_, "NUMERIC") elif type_.scale is None: - return self._extend_numeric(type_, - "NUMERIC(%(precision)s)" % - {'precision': type_.precision}) + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s)" % {"precision": type_.precision}, + ) else: - return self._extend_numeric(type_, - "NUMERIC(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "NUMERIC(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) def visit_DECIMAL(self, type_, **kw): if type_.precision is None: return self._extend_numeric(type_, "DECIMAL") elif type_.scale is None: - return self._extend_numeric(type_, - "DECIMAL(%(precision)s)" % - {'precision': type_.precision}) + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s)" % {"precision": type_.precision}, + ) else: - return self._extend_numeric(type_, - "DECIMAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "DECIMAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) def visit_DOUBLE(self, type_, **kw): if type_.precision is not None and type_.scale is not None: - return self._extend_numeric(type_, - "DOUBLE(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "DOUBLE(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) else: - return self._extend_numeric(type_, 'DOUBLE') + return self._extend_numeric(type_, "DOUBLE") def visit_REAL(self, type_, **kw): if type_.precision is not None and type_.scale is not None: - return self._extend_numeric(type_, - "REAL(%(precision)s, %(scale)s)" % - {'precision': type_.precision, - 'scale': type_.scale}) + return self._extend_numeric( + type_, + "REAL(%(precision)s, %(scale)s)" + % {"precision": type_.precision, "scale": type_.scale}, + ) else: - return self._extend_numeric(type_, 'REAL') + return self._extend_numeric(type_, "REAL") def visit_FLOAT(self, type_, **kw): - if self._mysql_type(type_) and \ - type_.scale is not None and \ - type_.precision is not None: + if ( + self._mysql_type(type_) + and type_.scale is not None + and type_.precision is not None + ): return self._extend_numeric( - type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale)) + type_, "FLOAT(%s, %s)" % (type_.precision, type_.scale) + ) elif type_.precision is not None: - return self._extend_numeric(type_, - "FLOAT(%s)" % (type_.precision,)) + return self._extend_numeric( + type_, "FLOAT(%s)" % (type_.precision,) + ) else: return self._extend_numeric(type_, "FLOAT") def visit_INTEGER(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "INTEGER(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "INTEGER(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "INTEGER") def visit_BIGINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "BIGINT(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "BIGINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "BIGINT") def visit_MEDIUMINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: return self._extend_numeric( - type_, "MEDIUMINT(%(display_width)s)" % - {'display_width': type_.display_width}) + type_, + "MEDIUMINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "MEDIUMINT") def visit_TINYINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: - return self._extend_numeric(type_, - "TINYINT(%s)" % type_.display_width) + return self._extend_numeric( + type_, "TINYINT(%s)" % type_.display_width + ) else: return self._extend_numeric(type_, "TINYINT") def visit_SMALLINT(self, type_, **kw): if self._mysql_type(type_) and type_.display_width is not None: - return self._extend_numeric(type_, - "SMALLINT(%(display_width)s)" % - {'display_width': type_.display_width} - ) + return self._extend_numeric( + type_, + "SMALLINT(%(display_width)s)" + % {"display_width": type_.display_width}, + ) else: return self._extend_numeric(type_, "SMALLINT") @@ -1563,7 +1895,7 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return "BIT" def visit_DATETIME(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "DATETIME(%d)" % type_.fsp else: return "DATETIME" @@ -1572,13 +1904,13 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): return "DATE" def visit_TIME(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "TIME(%d)" % type_.fsp else: return "TIME" def visit_TIMESTAMP(self, type_, **kw): - if getattr(type_, 'fsp', None): + if getattr(type_, "fsp", None): return "TIMESTAMP(%d)" % type_.fsp else: return "TIMESTAMP" @@ -1606,17 +1938,17 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): def visit_VARCHAR(self, type_, **kw): if type_.length: - return self._extend_string( - type_, {}, "VARCHAR(%d)" % type_.length) + return self._extend_string(type_, {}, "VARCHAR(%d)" % type_.length) else: raise exc.CompileError( - "VARCHAR requires a length on dialect %s" % - self.dialect.name) + "VARCHAR requires a length on dialect %s" % self.dialect.name + ) def visit_CHAR(self, type_, **kw): if type_.length: - return self._extend_string(type_, {}, "CHAR(%(length)s)" % - {'length': type_.length}) + return self._extend_string( + type_, {}, "CHAR(%(length)s)" % {"length": type_.length} + ) else: return self._extend_string(type_, {}, "CHAR") @@ -1625,22 +1957,26 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): # of "NVARCHAR". if type_.length: return self._extend_string( - type_, {'national': True}, - "VARCHAR(%(length)s)" % {'length': type_.length}) + type_, + {"national": True}, + "VARCHAR(%(length)s)" % {"length": type_.length}, + ) else: raise exc.CompileError( - "NVARCHAR requires a length on dialect %s" % - self.dialect.name) + "NVARCHAR requires a length on dialect %s" % self.dialect.name + ) def visit_NCHAR(self, type_, **kw): # We'll actually generate the equiv. # "NATIONAL CHAR" instead of "NCHAR". if type_.length: return self._extend_string( - type_, {'national': True}, - "CHAR(%(length)s)" % {'length': type_.length}) + type_, + {"national": True}, + "CHAR(%(length)s)" % {"length": type_.length}, + ) else: - return self._extend_string(type_, {'national': True}, "CHAR") + return self._extend_string(type_, {"national": True}, "CHAR") def visit_VARBINARY(self, type_, **kw): return "VARBINARY(%d)" % type_.length @@ -1676,17 +2012,19 @@ class MySQLTypeCompiler(compiler.GenericTypeCompiler): quoted_enums = [] for e in enumerated_values: quoted_enums.append("'%s'" % e.replace("'", "''")) - return self._extend_string(type_, {}, "%s(%s)" % ( - name, ",".join(quoted_enums)) + return self._extend_string( + type_, {}, "%s(%s)" % (name, ",".join(quoted_enums)) ) def visit_ENUM(self, type_, **kw): - return self._visit_enumerated_values("ENUM", type_, - type_._enumerated_values) + return self._visit_enumerated_values( + "ENUM", type_, type_._enumerated_values + ) def visit_SET(self, type_, **kw): - return self._visit_enumerated_values("SET", type_, - type_._enumerated_values) + return self._visit_enumerated_values( + "SET", type_, type_._enumerated_values + ) def visit_BOOLEAN(self, type, **kw): return "BOOL" @@ -1703,9 +2041,8 @@ class MySQLIdentifierPreparer(compiler.IdentifierPreparer): quote = '"' super(MySQLIdentifierPreparer, self).__init__( - dialect, - initial_quote=quote, - escape_quote=quote) + dialect, initial_quote=quote, escape_quote=quote + ) def _quote_free_identifiers(self, *ids): """Unilaterally identifier-quote any number of strings.""" @@ -1719,7 +2056,7 @@ class MySQLDialect(default.DefaultDialect): Not used directly in application code. """ - name = 'mysql' + name = "mysql" supports_alter = True # MySQL has no true "boolean" type; we @@ -1738,7 +2075,7 @@ class MySQLDialect(default.DefaultDialect): supports_comments = True inline_comments = True - default_paramstyle = 'format' + default_paramstyle = "format" colspecs = colspecs cte_follows_insert = True @@ -1756,26 +2093,28 @@ class MySQLDialect(default.DefaultDialect): _server_ansiquotes = False construct_arguments = [ - (sa_schema.Table, { - "*": None - }), - (sql.Update, { - "limit": None - }), - (sa_schema.PrimaryKeyConstraint, { - "using": None - }), - (sa_schema.Index, { - "using": None, - "length": None, - "prefix": None, - "with_parser": None - }) + (sa_schema.Table, {"*": None}), + (sql.Update, {"limit": None}), + (sa_schema.PrimaryKeyConstraint, {"using": None}), + ( + sa_schema.Index, + { + "using": None, + "length": None, + "prefix": None, + "with_parser": None, + }, + ), ] - def __init__(self, isolation_level=None, json_serializer=None, - json_deserializer=None, **kwargs): - kwargs.pop('use_ansiquotes', None) # legacy + def __init__( + self, + isolation_level=None, + json_serializer=None, + json_deserializer=None, + **kwargs + ): + kwargs.pop("use_ansiquotes", None) # legacy default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_serializer = json_serializer @@ -1783,22 +2122,30 @@ class MySQLDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ] + ) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") # adjust for ConnectionFairy being present # allows attribute set e.g. "connection.autocommit = True" # to work properly - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection self._set_isolation_level(connection, level) @@ -1807,8 +2154,8 @@ class MySQLDialect(default.DefaultDialect): if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() cursor.execute("SET SESSION TRANSACTION ISOLATION LEVEL %s" % level) @@ -1818,9 +2165,9 @@ class MySQLDialect(default.DefaultDialect): def get_isolation_level(self, connection): cursor = connection.cursor() if self._is_mysql and self.server_version_info >= (5, 7, 20): - cursor.execute('SELECT @@transaction_isolation') + cursor.execute("SELECT @@transaction_isolation") else: - cursor.execute('SELECT @@tx_isolation') + cursor.execute("SELECT @@tx_isolation") val = cursor.fetchone()[0] cursor.close() if util.py3k and isinstance(val, bytes): @@ -1840,7 +2187,7 @@ class MySQLDialect(default.DefaultDialect): val = val.decode() version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") for n in r.split(val): try: version.append(int(n)) @@ -1885,29 +2232,38 @@ class MySQLDialect(default.DefaultDialect): connection.execute(sql.text("XA END :xid"), xid=xid) connection.execute(sql.text("XA PREPARE :xid"), xid=xid) - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: connection.execute(sql.text("XA END :xid"), xid=xid) connection.execute(sql.text("XA ROLLBACK :xid"), xid=xid) - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_prepare_twophase(connection, xid) connection.execute(sql.text("XA COMMIT :xid"), xid=xid) def do_recover_twophase(self, connection): resultset = connection.execute("XA RECOVER") - return [row['data'][0:row['gtrid_length']] for row in resultset] + return [row["data"][0 : row["gtrid_length"]] for row in resultset] def is_disconnect(self, e, connection, cursor): - if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): - return self._extract_error_code(e) in \ - (2006, 2013, 2014, 2045, 2055) + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): + return self._extract_error_code(e) in ( + 2006, + 2013, + 2014, + 2045, + 2055, + ) elif isinstance( - e, (self.dbapi.InterfaceError, self.dbapi.InternalError)): + e, (self.dbapi.InterfaceError, self.dbapi.InternalError) + ): # if underlying connection is closed, # this is the error you get return "(0, '')" in str(e) @@ -1944,7 +2300,7 @@ class MySQLDialect(default.DefaultDialect): raise NotImplementedError() def _get_default_schema_name(self, connection): - return connection.execute('SELECT DATABASE()').scalar() + return connection.execute("SELECT DATABASE()").scalar() def has_table(self, connection, table_name, schema=None): # SHOW TABLE STATUS LIKE and SHOW TABLES LIKE do not function properly @@ -1957,15 +2313,19 @@ class MySQLDialect(default.DefaultDialect): # full_name = self.identifier_preparer.format_table(table, # use_schema=True) - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, table_name)) + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) st = "DESCRIBE %s" % full_name rs = None try: try: rs = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) have = rs.fetchone() is not None rs.close() return have @@ -1986,12 +2346,13 @@ class MySQLDialect(default.DefaultDialect): # if ansiquotes == True, build a new IdentifierPreparer # with the new setting self.identifier_preparer = self.preparer( - self, server_ansiquotes=self._server_ansiquotes) + self, server_ansiquotes=self._server_ansiquotes + ) default.DefaultDialect.initialize(self, connection) self._needs_correct_for_88718 = ( - not self._is_mariadb and self.server_version_info >= (8, ) + not self._is_mariadb and self.server_version_info >= (8,) ) self._warn_for_known_db_issues() @@ -2007,20 +2368,23 @@ class MySQLDialect(default.DefaultDialect): "additional issue prevents proper migrations of columns " "with CHECK constraints (MDEV-11114). Please upgrade to " "MariaDB 10.2.9 or greater, or use the MariaDB 10.1 " - "series, to avoid these issues." % (mdb_version, )) + "series, to avoid these issues." % (mdb_version,) + ) @property def _is_mariadb(self): - return 'MariaDB' in self.server_version_info + return "MariaDB" in self.server_version_info @property def _is_mysql(self): - return 'MariaDB' not in self.server_version_info + return "MariaDB" not in self.server_version_info @property def _is_mariadb_102(self): - return self._is_mariadb and \ - self._mariadb_normalized_version_info > (10, 2) + return self._is_mariadb and self._mariadb_normalized_version_info > ( + 10, + 2, + ) @property def _mariadb_normalized_version_info(self): @@ -2028,15 +2392,17 @@ class MySQLDialect(default.DefaultDialect): # the string "5.5"; now that we use @@version we no longer see this. if self._is_mariadb: - idx = self.server_version_info.index('MariaDB') - return self.server_version_info[idx - 3: idx] + idx = self.server_version_info.index("MariaDB") + return self.server_version_info[idx - 3 : idx] else: return self.server_version_info @property def _supports_cast(self): - return self.server_version_info is None or \ - self.server_version_info >= (4, 0, 2) + return ( + self.server_version_info is None + or self.server_version_info >= (4, 0, 2) + ) @reflection.cache def get_schema_names(self, connection, **kw): @@ -2054,18 +2420,23 @@ class MySQLDialect(default.DefaultDialect): charset = self._connection_charset if self.server_version_info < (5, 0, 2): rp = connection.execute( - "SHOW TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) - return [row[0] for - row in self._compat_fetchall(rp, charset=charset)] + "SHOW TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) + return [ + row[0] for row in self._compat_fetchall(rp, charset=charset) + ] else: rp = connection.execute( - "SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(current_schema)) + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(current_schema) + ) - return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] == 'BASE TABLE'] + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] == "BASE TABLE" + ] @reflection.cache def get_view_names(self, connection, schema=None, **kw): @@ -2077,72 +2448,77 @@ class MySQLDialect(default.DefaultDialect): return self.get_table_names(connection, schema) charset = self._connection_charset rp = connection.execute( - "SHOW FULL TABLES FROM %s" % - self.identifier_preparer.quote_identifier(schema)) - return [row[0] - for row in self._compat_fetchall(rp, charset=charset) - if row[1] in ('VIEW', 'SYSTEM VIEW')] + "SHOW FULL TABLES FROM %s" + % self.identifier_preparer.quote_identifier(schema) + ) + return [ + row[0] + for row in self._compat_fetchall(rp, charset=charset) + if row[1] in ("VIEW", "SYSTEM VIEW") + ] @reflection.cache def get_table_options(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return parsed_state.table_options @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return parsed_state.columns @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) for key in parsed_state.keys: - if key['type'] == 'PRIMARY': + if key["type"] == "PRIMARY": # There can be only one. - cols = [s[0] for s in key['columns']] - return {'constrained_columns': cols, 'name': None} - return {'constrained_columns': [], 'name': None} + cols = [s[0] for s in key["columns"]] + return {"constrained_columns": cols, "name": None} + return {"constrained_columns": [], "name": None} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) default_schema = None fkeys = [] for spec in parsed_state.fk_constraints: - ref_name = spec['table'][-1] - ref_schema = len(spec['table']) > 1 and \ - spec['table'][-2] or schema + ref_name = spec["table"][-1] + ref_schema = len(spec["table"]) > 1 and spec["table"][-2] or schema if not ref_schema: if default_schema is None: - default_schema = \ - connection.dialect.default_schema_name + default_schema = connection.dialect.default_schema_name if schema == default_schema: ref_schema = schema - loc_names = spec['local'] - ref_names = spec['foreign'] + loc_names = spec["local"] + ref_names = spec["foreign"] con_kw = {} - for opt in ('onupdate', 'ondelete'): + for opt in ("onupdate", "ondelete"): if spec.get(opt, False): con_kw[opt] = spec[opt] fkey_d = { - 'name': spec['name'], - 'constrained_columns': loc_names, - 'referred_schema': ref_schema, - 'referred_table': ref_name, - 'referred_columns': ref_names, - 'options': con_kw + "name": spec["name"], + "constrained_columns": loc_names, + "referred_schema": ref_schema, + "referred_table": ref_name, + "referred_columns": ref_names, + "options": con_kw, } fkeys.append(fkey_d) @@ -2172,25 +2548,26 @@ class MySQLDialect(default.DefaultDialect): default_schema_name = connection.dialect.default_schema_name col_tuples = [ ( - lower(rec['referred_schema'] or default_schema_name), - lower(rec['referred_table']), - col_name + lower(rec["referred_schema"] or default_schema_name), + lower(rec["referred_table"]), + col_name, ) for rec in fkeys - for col_name in rec['referred_columns'] + for col_name in rec["referred_columns"] ] if col_tuples: correct_for_wrong_fk_case = connection.execute( - sql.text(""" + sql.text( + """ select table_schema, table_name, column_name from information_schema.columns where (table_schema, table_name, lower(column_name)) in :table_data; - """).bindparams( - sql.bindparam("table_data", expanding=True) - ), table_data=col_tuples + """ + ).bindparams(sql.bindparam("table_data", expanding=True)), + table_data=col_tuples, ) # in casing=0, table name and schema name come back in their @@ -2208,109 +2585,117 @@ class MySQLDialect(default.DefaultDialect): d[(lower(schema), lower(tname))][cname.lower()] = cname for fkey in fkeys: - fkey['referred_columns'] = [ + fkey["referred_columns"] = [ d[ ( lower( - fkey['referred_schema'] or - default_schema_name), - lower(fkey['referred_table']) + fkey["referred_schema"] or default_schema_name + ), + lower(fkey["referred_table"]), ) ][col.lower()] - for col in fkey['referred_columns'] + for col in fkey["referred_columns"] ] @reflection.cache - def get_check_constraints( - self, connection, table_name, schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return [ - {"name": spec['name'], "sqltext": spec['sqltext']} + {"name": spec["name"], "sqltext": spec["sqltext"]} for spec in parsed_state.ck_constraints ] @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) - return {"text": parsed_state.table_options.get('mysql_comment', None)} + connection, table_name, schema, **kw + ) + return {"text": parsed_state.table_options.get("mysql_comment", None)} @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) indexes = [] for spec in parsed_state.keys: dialect_options = {} unique = False - flavor = spec['type'] - if flavor == 'PRIMARY': + flavor = spec["type"] + if flavor == "PRIMARY": continue - if flavor == 'UNIQUE': + if flavor == "UNIQUE": unique = True - elif flavor in ('FULLTEXT', 'SPATIAL'): + elif flavor in ("FULLTEXT", "SPATIAL"): dialect_options["mysql_prefix"] = flavor elif flavor is None: pass else: self.logger.info( - "Converting unknown KEY type %s to a plain KEY", flavor) + "Converting unknown KEY type %s to a plain KEY", flavor + ) pass - if spec['parser']: - dialect_options['mysql_with_parser'] = spec['parser'] + if spec["parser"]: + dialect_options["mysql_with_parser"] = spec["parser"] index_d = {} if dialect_options: index_d["dialect_options"] = dialect_options - index_d['name'] = spec['name'] - index_d['column_names'] = [s[0] for s in spec['columns']] - index_d['unique'] = unique + index_d["name"] = spec["name"] + index_d["column_names"] = [s[0] for s in spec["columns"]] + index_d["unique"] = unique if flavor: - index_d['type'] = flavor + index_d["type"] = flavor indexes.append(index_d) return indexes @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): parsed_state = self._parsed_state_or_create( - connection, table_name, schema, **kw) + connection, table_name, schema, **kw + ) return [ { - 'name': key['name'], - 'column_names': [col[0] for col in key['columns']], - 'duplicates_index': key['name'], + "name": key["name"], + "column_names": [col[0] for col in key["columns"]], + "duplicates_index": key["name"], } for key in parsed_state.keys - if key['type'] == 'UNIQUE' + if key["type"] == "UNIQUE" ] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): charset = self._connection_charset - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, view_name)) - sql = self._show_create_table(connection, None, charset, - full_name=full_name) + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers(schema, view_name) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) return sql - def _parsed_state_or_create(self, connection, table_name, - schema=None, **kw): + def _parsed_state_or_create( + self, connection, table_name, schema=None, **kw + ): return self._setup_parser( connection, table_name, schema, - info_cache=kw.get('info_cache', None) + info_cache=kw.get("info_cache", None), ) @util.memoized_property @@ -2321,7 +2706,7 @@ class MySQLDialect(default.DefaultDialect): retrieved server version information first. """ - if (self.server_version_info < (4, 1) and self._server_ansiquotes): + if self.server_version_info < (4, 1) and self._server_ansiquotes: # ANSI_QUOTES doesn't affect SHOW CREATE TABLE on < 4.1 preparer = self.preparer(self, server_ansiquotes=False) else: @@ -2332,14 +2717,19 @@ class MySQLDialect(default.DefaultDialect): def _setup_parser(self, connection, table_name, schema=None, **kw): charset = self._connection_charset parser = self._tabledef_parser - full_name = '.'.join(self.identifier_preparer._quote_free_identifiers( - schema, table_name)) - sql = self._show_create_table(connection, None, charset, - full_name=full_name) - if re.match(r'^CREATE (?:ALGORITHM)?.* VIEW', sql): + full_name = ".".join( + self.identifier_preparer._quote_free_identifiers( + schema, table_name + ) + ) + sql = self._show_create_table( + connection, None, charset, full_name=full_name + ) + if re.match(r"^CREATE (?:ALGORITHM)?.* VIEW", sql): # Adapt views to something table-like. - columns = self._describe_table(connection, None, charset, - full_name=full_name) + columns = self._describe_table( + connection, None, charset, full_name=full_name + ) sql = parser._describe_to_create(table_name, columns) return parser.parse(sql, charset) @@ -2356,17 +2746,18 @@ class MySQLDialect(default.DefaultDialect): # http://dev.mysql.com/doc/refman/5.0/en/name-case-sensitivity.html charset = self._connection_charset - row = self._compat_first(connection.execute( - "SHOW VARIABLES LIKE 'lower_case_table_names'"), - charset=charset) + row = self._compat_first( + connection.execute("SHOW VARIABLES LIKE 'lower_case_table_names'"), + charset=charset, + ) if not row: cs = 0 else: # 4.0.15 returns OFF or ON according to [ticket:489] # 3.23 doesn't, 4.0.27 doesn't.. - if row[1] == 'OFF': + if row[1] == "OFF": cs = 0 - elif row[1] == 'ON': + elif row[1] == "ON": cs = 1 else: cs = int(row[1]) @@ -2384,7 +2775,7 @@ class MySQLDialect(default.DefaultDialect): pass else: charset = self._connection_charset - rs = connection.execute('SHOW COLLATION') + rs = connection.execute("SHOW COLLATION") for row in self._compat_fetchall(rs, charset): collations[row[0]] = row[1] return collations @@ -2392,33 +2783,36 @@ class MySQLDialect(default.DefaultDialect): def _detect_sql_mode(self, connection): row = self._compat_first( connection.execute("SHOW VARIABLES LIKE 'sql_mode'"), - charset=self._connection_charset) + charset=self._connection_charset, + ) if not row: util.warn( "Could not retrieve SQL_MODE; please ensure the " - "MySQL user has permissions to SHOW VARIABLES") - self._sql_mode = '' + "MySQL user has permissions to SHOW VARIABLES" + ) + self._sql_mode = "" else: - self._sql_mode = row[1] or '' + self._sql_mode = row[1] or "" def _detect_ansiquotes(self, connection): """Detect and adjust for the ANSI_QUOTES sql mode.""" mode = self._sql_mode if not mode: - mode = '' + mode = "" elif mode.isdigit(): mode_no = int(mode) - mode = (mode_no | 4 == mode_no) and 'ANSI_QUOTES' or '' + mode = (mode_no | 4 == mode_no) and "ANSI_QUOTES" or "" - self._server_ansiquotes = 'ANSI_QUOTES' in mode + self._server_ansiquotes = "ANSI_QUOTES" in mode # as of MySQL 5.0.1 - self._backslash_escapes = 'NO_BACKSLASH_ESCAPES' not in mode + self._backslash_escapes = "NO_BACKSLASH_ESCAPES" not in mode - def _show_create_table(self, connection, table, charset=None, - full_name=None): + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): """Run SHOW CREATE TABLE for a ``Table``.""" if full_name is None: @@ -2428,7 +2822,8 @@ class MySQLDialect(default.DefaultDialect): rp = None try: rp = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) except exc.DBAPIError as e: if self._extract_error_code(e.orig) == 1146: raise exc.NoSuchTableError(full_name) @@ -2441,8 +2836,7 @@ class MySQLDialect(default.DefaultDialect): return sql - def _describe_table(self, connection, table, charset=None, - full_name=None): + def _describe_table(self, connection, table, charset=None, full_name=None): """Run DESCRIBE for a ``Table`` and return processed rows.""" if full_name is None: @@ -2453,7 +2847,8 @@ class MySQLDialect(default.DefaultDialect): try: try: rp = connection.execution_options( - skip_user_error_events=True).execute(st) + skip_user_error_events=True + ).execute(st) except exc.DBAPIError as e: code = self._extract_error_code(e.orig) if code == 1146: @@ -2486,11 +2881,11 @@ class _DecodingRowProxy(object): # seem to come up in DDL queries. _encoding_compat = { - 'koi8r': 'koi8_r', - 'koi8u': 'koi8_u', - 'utf16': 'utf-16-be', # MySQL's uft16 is always bigendian - 'utf8mb4': 'utf8', # real utf8 - 'eucjpms': 'ujis', + "koi8r": "koi8_r", + "koi8u": "koi8_u", + "utf16": "utf-16-be", # MySQL's uft16 is always bigendian + "utf8mb4": "utf8", # real utf8 + "eucjpms": "ujis", } def __init__(self, rowproxy, charset): diff --git a/lib/sqlalchemy/dialects/mysql/cymysql.py b/lib/sqlalchemy/dialects/mysql/cymysql.py index d14290594..8a60608db 100644 --- a/lib/sqlalchemy/dialects/mysql/cymysql.py +++ b/lib/sqlalchemy/dialects/mysql/cymysql.py @@ -18,7 +18,7 @@ import re from .mysqldb import MySQLDialect_mysqldb -from .base import (BIT, MySQLDialect) +from .base import BIT, MySQLDialect from ... import util @@ -34,27 +34,23 @@ class _cymysqlBIT(BIT): v = v << 8 | i return v return value + return process class MySQLDialect_cymysql(MySQLDialect_mysqldb): - driver = 'cymysql' + driver = "cymysql" description_encoding = None supports_sane_rowcount = True supports_sane_multi_rowcount = False supports_unicode_statements = True - colspecs = util.update_copy( - MySQLDialect.colspecs, - { - BIT: _cymysqlBIT, - } - ) + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT}) @classmethod def dbapi(cls): - return __import__('cymysql') + return __import__("cymysql") def _detect_charset(self, connection): return connection.connection.charset @@ -64,8 +60,13 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.OperationalError): - return self._extract_error_code(e) in \ - (2006, 2013, 2014, 2045, 2055) + return self._extract_error_code(e) in ( + 2006, + 2013, + 2014, + 2045, + 2055, + ) elif isinstance(e, self.dbapi.InterfaceError): # if underlying connection is closed, # this is the error you get @@ -73,4 +74,5 @@ class MySQLDialect_cymysql(MySQLDialect_mysqldb): else: return False + dialect = MySQLDialect_cymysql diff --git a/lib/sqlalchemy/dialects/mysql/dml.py b/lib/sqlalchemy/dialects/mysql/dml.py index 130ef2347..5d59b2073 100644 --- a/lib/sqlalchemy/dialects/mysql/dml.py +++ b/lib/sqlalchemy/dialects/mysql/dml.py @@ -6,7 +6,7 @@ from ...sql.base import _generative from ... import exc from ... import util -__all__ = ('Insert', 'insert') +__all__ = ("Insert", "insert") class Insert(StandardInsert): @@ -39,7 +39,7 @@ class Insert(StandardInsert): @util.memoized_property def inserted_alias(self): - return alias(self.table, name='inserted') + return alias(self.table, name="inserted") @_generative def on_duplicate_key_update(self, *args, **kw): @@ -87,27 +87,29 @@ class Insert(StandardInsert): """ if args and kw: raise exc.ArgumentError( - "Can't pass kwargs and positional arguments simultaneously") + "Can't pass kwargs and positional arguments simultaneously" + ) if args: if len(args) > 1: raise exc.ArgumentError( "Only a single dictionary or list of tuples " - "is accepted positionally.") + "is accepted positionally." + ) values = args[0] else: values = kw - inserted_alias = getattr(self, 'inserted_alias', None) + inserted_alias = getattr(self, "inserted_alias", None) self._post_values_clause = OnDuplicateClause(inserted_alias, values) return self -insert = public_factory(Insert, '.dialects.mysql.insert') +insert = public_factory(Insert, ".dialects.mysql.insert") class OnDuplicateClause(ClauseElement): - __visit_name__ = 'on_duplicate_key_update' + __visit_name__ = "on_duplicate_key_update" _parameter_ordering = None @@ -118,11 +120,12 @@ class OnDuplicateClause(ClauseElement): # Update._proces_colparams(), however we don't look for a special flag # in this case since we are not disambiguating from other use cases as # we are in Update.values(). - if isinstance(update, list) and \ - (update and isinstance(update[0], tuple)): + if isinstance(update, list) and ( + update and isinstance(update[0], tuple) + ): self._parameter_ordering = [key for key, value in update] update = dict(update) if not update or not isinstance(update, dict): - raise ValueError('update parameter must be a non-empty dictionary') + raise ValueError("update parameter must be a non-empty dictionary") self.update = update diff --git a/lib/sqlalchemy/dialects/mysql/enumerated.py b/lib/sqlalchemy/dialects/mysql/enumerated.py index f63d64e8f..9586eff3f 100644 --- a/lib/sqlalchemy/dialects/mysql/enumerated.py +++ b/lib/sqlalchemy/dialects/mysql/enumerated.py @@ -14,29 +14,30 @@ from ...sql import sqltypes class _EnumeratedValues(_StringType): def _init_values(self, values, kw): - self.quoting = kw.pop('quoting', 'auto') + self.quoting = kw.pop("quoting", "auto") - if self.quoting == 'auto' and len(values): + if self.quoting == "auto" and len(values): # What quoting character are we using? q = None for e in values: if len(e) == 0: - self.quoting = 'unquoted' + self.quoting = "unquoted" break elif q is None: q = e[0] if len(e) == 1 or e[0] != q or e[-1] != q: - self.quoting = 'unquoted' + self.quoting = "unquoted" break else: - self.quoting = 'quoted' + self.quoting = "quoted" - if self.quoting == 'quoted': + if self.quoting == "quoted": util.warn_deprecated( - 'Manually quoting %s value literals is deprecated. Supply ' - 'unquoted values and use the quoting= option in cases of ' - 'ambiguity.' % self.__class__.__name__) + "Manually quoting %s value literals is deprecated. Supply " + "unquoted values and use the quoting= option in cases of " + "ambiguity." % self.__class__.__name__ + ) values = self._strip_values(values) @@ -58,7 +59,7 @@ class _EnumeratedValues(_StringType): class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): """MySQL ENUM type.""" - __visit_name__ = 'ENUM' + __visit_name__ = "ENUM" native_enum = True @@ -115,7 +116,7 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): """ - kw.pop('strict', None) + kw.pop("strict", None) self._enum_init(enums, kw) _StringType.__init__(self, length=self.length, **kw) @@ -145,13 +146,14 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _EnumeratedValues): def __repr__(self): return util.generic_repr( - self, to_inspect=[ENUM, _StringType, sqltypes.Enum]) + self, to_inspect=[ENUM, _StringType, sqltypes.Enum] + ) class SET(_EnumeratedValues): """MySQL SET type.""" - __visit_name__ = 'SET' + __visit_name__ = "SET" def __init__(self, *values, **kw): """Construct a SET. @@ -216,45 +218,43 @@ class SET(_EnumeratedValues): """ - self.retrieve_as_bitwise = kw.pop('retrieve_as_bitwise', False) + self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False) values, length = self._init_values(values, kw) self.values = tuple(values) - if not self.retrieve_as_bitwise and '' in values: + if not self.retrieve_as_bitwise and "" in values: raise exc.ArgumentError( "Can't use the blank value '' in a SET without " - "setting retrieve_as_bitwise=True") + "setting retrieve_as_bitwise=True" + ) if self.retrieve_as_bitwise: self._bitmap = dict( - (value, 2 ** idx) - for idx, value in enumerate(self.values) + (value, 2 ** idx) for idx, value in enumerate(self.values) ) self._bitmap.update( - (2 ** idx, value) - for idx, value in enumerate(self.values) + (2 ** idx, value) for idx, value in enumerate(self.values) ) - kw.setdefault('length', length) + kw.setdefault("length", length) super(SET, self).__init__(**kw) def column_expression(self, colexpr): if self.retrieve_as_bitwise: return sql.type_coerce( - sql.type_coerce(colexpr, sqltypes.Integer) + 0, - self + sql.type_coerce(colexpr, sqltypes.Integer) + 0, self ) else: return colexpr def result_processor(self, dialect, coltype): if self.retrieve_as_bitwise: + def process(value): if value is not None: value = int(value) - return set( - util.map_bits(self._bitmap.__getitem__, value) - ) + return set(util.map_bits(self._bitmap.__getitem__, value)) else: return None + else: super_convert = super(SET, self).result_processor(dialect, coltype) @@ -263,18 +263,20 @@ class SET(_EnumeratedValues): # MySQLdb returns a string, let's parse if super_convert: value = super_convert(value) - return set(re.findall(r'[^,]+', value)) + return set(re.findall(r"[^,]+", value)) else: # mysql-connector-python does a naive # split(",") which throws in an empty string if value is not None: - value.discard('') + value.discard("") return value + return process def bind_processor(self, dialect): super_convert = super(SET, self).bind_processor(dialect) if self.retrieve_as_bitwise: + def process(value): if value is None: return None @@ -288,24 +290,23 @@ class SET(_EnumeratedValues): for v in value: int_value |= self._bitmap[v] return int_value + else: def process(value): # accept strings and int (actually bitflag) values directly if value is not None and not isinstance( - value, util.int_types + util.string_types): + value, util.int_types + util.string_types + ): value = ",".join(value) if super_convert: return super_convert(value) else: return value + return process def adapt(self, impltype, **kw): - kw['retrieve_as_bitwise'] = self.retrieve_as_bitwise - return util.constructor_copy( - self, impltype, - *self.values, - **kw - ) + kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise + return util.constructor_copy(self, impltype, *self.values, **kw) diff --git a/lib/sqlalchemy/dialects/mysql/gaerdbms.py b/lib/sqlalchemy/dialects/mysql/gaerdbms.py index 806e4c874..117cd28a2 100644 --- a/lib/sqlalchemy/dialects/mysql/gaerdbms.py +++ b/lib/sqlalchemy/dialects/mysql/gaerdbms.py @@ -44,11 +44,10 @@ from sqlalchemy.util import warn_deprecated def _is_dev_environment(): - return os.environ.get('SERVER_SOFTWARE', '').startswith('Development/') + return os.environ.get("SERVER_SOFTWARE", "").startswith("Development/") class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): - @classmethod def dbapi(cls): @@ -69,12 +68,15 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): if _is_dev_environment(): from google.appengine.api import rdbms_mysqldb + return rdbms_mysqldb - elif apiproxy_stub_map.apiproxy.GetStub('rdbms'): + elif apiproxy_stub_map.apiproxy.GetStub("rdbms"): from google.storage.speckle.python.api import rdbms_apiproxy + return rdbms_apiproxy else: from google.storage.speckle.python.api import rdbms_googleapi + return rdbms_googleapi @classmethod @@ -87,8 +89,8 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): if not _is_dev_environment(): # 'dsn' and 'instance' are because we are skipping # the traditional google.api.rdbms wrapper - opts['dsn'] = '' - opts['instance'] = url.query['instance'] + opts["dsn"] = "" + opts["instance"] = url.query["instance"] return [], opts def _extract_error_code(self, exception): @@ -99,4 +101,5 @@ class MySQLDialect_gaerdbms(MySQLDialect_mysqldb): if code: return int(code) + dialect = MySQLDialect_gaerdbms diff --git a/lib/sqlalchemy/dialects/mysql/json.py b/lib/sqlalchemy/dialects/mysql/json.py index 534fb989d..162d48f73 100644 --- a/lib/sqlalchemy/dialects/mysql/json.py +++ b/lib/sqlalchemy/dialects/mysql/json.py @@ -58,7 +58,6 @@ class _FormatTypeMixin(object): class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): - def _format_value(self, value): if isinstance(value, int): value = "$[%s]" % value @@ -70,8 +69,10 @@ class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): def _format_value(self, value): return "$%s" % ( - "".join([ - "[%s]" % elem if isinstance(elem, int) - else '."%s"' % elem for elem in value - ]) + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) ) diff --git a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py index e16b68bad..9c1502a14 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqlconnector.py +++ b/lib/sqlalchemy/dialects/mysql/mysqlconnector.py @@ -47,9 +47,13 @@ are contributed to SQLAlchemy. """ -from .base import (MySQLDialect, MySQLExecutionContext, - MySQLCompiler, MySQLIdentifierPreparer, - BIT) +from .base import ( + MySQLDialect, + MySQLExecutionContext, + MySQLCompiler, + MySQLIdentifierPreparer, + BIT, +) from ... import util import re @@ -57,7 +61,6 @@ from ... import processors class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): - def get_lastrowid(self): return self.cursor.lastrowid @@ -65,21 +68,27 @@ class MySQLExecutionContext_mysqlconnector(MySQLExecutionContext): class MySQLCompiler_mysqlconnector(MySQLCompiler): def visit_mod_binary(self, binary, operator, **kw): if self.dialect._mysqlconnector_double_percents: - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) else: - return self.process(binary.left, **kw) + " % " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " % " + + self.process(binary.right, **kw) + ) def post_process_text(self, text): if self.dialect._mysqlconnector_double_percents: - return text.replace('%', '%%') + return text.replace("%", "%%") else: return text def escape_literal_column(self, text): if self.dialect._mysqlconnector_double_percents: - return text.replace('%', '%%') + return text.replace("%", "%%") else: return text @@ -109,7 +118,7 @@ class _myconnpyBIT(BIT): class MySQLDialect_mysqlconnector(MySQLDialect): - driver = 'mysqlconnector' + driver = "mysqlconnector" supports_unicode_binds = True @@ -118,28 +127,22 @@ class MySQLDialect_mysqlconnector(MySQLDialect): supports_native_decimal = True - default_paramstyle = 'format' + default_paramstyle = "format" execution_ctx_cls = MySQLExecutionContext_mysqlconnector statement_compiler = MySQLCompiler_mysqlconnector preparer = MySQLIdentifierPreparer_mysqlconnector - colspecs = util.update_copy( - MySQLDialect.colspecs, - { - BIT: _myconnpyBIT, - } - ) + colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT}) def __init__(self, *arg, **kw): super(MySQLDialect_mysqlconnector, self).__init__(*arg, **kw) # hack description encoding since mysqlconnector randomly # returns bytes or not - self._description_decoder = \ - processors.to_conditional_unicode_processor_factory( - self.description_encoding - ) + self._description_decoder = processors.to_conditional_unicode_processor_factory( + self.description_encoding + ) def _check_unicode_description(self, connection): # hack description encoding since mysqlconnector randomly @@ -158,6 +161,7 @@ class MySQLDialect_mysqlconnector(MySQLDialect): @classmethod def dbapi(cls): from mysql import connector + return connector def do_ping(self, dbapi_connection): @@ -172,54 +176,52 @@ class MySQLDialect_mysqlconnector(MySQLDialect): return True def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') + opts = url.translate_connect_args(username="user") opts.update(url.query) - util.coerce_kw_type(opts, 'allow_local_infile', bool) - util.coerce_kw_type(opts, 'autocommit', bool) - util.coerce_kw_type(opts, 'buffered', bool) - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'connection_timeout', int) - util.coerce_kw_type(opts, 'connect_timeout', int) - util.coerce_kw_type(opts, 'consume_results', bool) - util.coerce_kw_type(opts, 'force_ipv6', bool) - util.coerce_kw_type(opts, 'get_warnings', bool) - util.coerce_kw_type(opts, 'pool_reset_session', bool) - util.coerce_kw_type(opts, 'pool_size', int) - util.coerce_kw_type(opts, 'raise_on_warnings', bool) - util.coerce_kw_type(opts, 'raw', bool) - util.coerce_kw_type(opts, 'ssl_verify_cert', bool) - util.coerce_kw_type(opts, 'use_pure', bool) - util.coerce_kw_type(opts, 'use_unicode', bool) + util.coerce_kw_type(opts, "allow_local_infile", bool) + util.coerce_kw_type(opts, "autocommit", bool) + util.coerce_kw_type(opts, "buffered", bool) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connection_timeout", int) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "consume_results", bool) + util.coerce_kw_type(opts, "force_ipv6", bool) + util.coerce_kw_type(opts, "get_warnings", bool) + util.coerce_kw_type(opts, "pool_reset_session", bool) + util.coerce_kw_type(opts, "pool_size", int) + util.coerce_kw_type(opts, "raise_on_warnings", bool) + util.coerce_kw_type(opts, "raw", bool) + util.coerce_kw_type(opts, "ssl_verify_cert", bool) + util.coerce_kw_type(opts, "use_pure", bool) + util.coerce_kw_type(opts, "use_unicode", bool) # unfortunately, MySQL/connector python refuses to release a # cursor without reading fully, so non-buffered isn't an option - opts.setdefault('buffered', True) + opts.setdefault("buffered", True) # FOUND_ROWS must be set in ClientFlag to enable # supports_sane_rowcount. if self.dbapi is not None: try: from mysql.connector.constants import ClientFlag + client_flags = opts.get( - 'client_flags', ClientFlag.get_default()) + "client_flags", ClientFlag.get_default() + ) client_flags |= ClientFlag.FOUND_ROWS - opts['client_flags'] = client_flags + opts["client_flags"] = client_flags except Exception: pass return [[], opts] @util.memoized_property def _mysqlconnector_version_info(self): - if self.dbapi and hasattr(self.dbapi, '__version__'): - m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', - self.dbapi.__version__) + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) if m: - return tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) @util.memoized_property def _mysqlconnector_double_percents(self): @@ -235,9 +237,11 @@ class MySQLDialect_mysqlconnector(MySQLDialect): errnos = (2006, 2013, 2014, 2045, 2055, 2048) exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError) if isinstance(e, exceptions): - return e.errno in errnos or \ - "MySQL Connection not available." in str(e) or \ - "Connection to MySQL is not available" in str(e) + return ( + e.errno in errnos + or "MySQL Connection not available." in str(e) + or "Connection to MySQL is not available" in str(e) + ) else: return False @@ -247,17 +251,24 @@ class MySQLDialect_mysqlconnector(MySQLDialect): def _compat_fetchone(self, rp, charset=None): return rp.fetchone() - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ', - 'AUTOCOMMIT']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ] + ) def _set_isolation_level(self, connection, level): - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit = True else: connection.autocommit = False super(MySQLDialect_mysqlconnector, self)._set_isolation_level( - connection, level) + connection, level + ) dialect = MySQLDialect_mysqlconnector diff --git a/lib/sqlalchemy/dialects/mysql/mysqldb.py b/lib/sqlalchemy/dialects/mysql/mysqldb.py index edac816fe..6d42f5c04 100644 --- a/lib/sqlalchemy/dialects/mysql/mysqldb.py +++ b/lib/sqlalchemy/dialects/mysql/mysqldb.py @@ -45,8 +45,12 @@ The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`. """ -from .base import (MySQLDialect, MySQLExecutionContext, - MySQLCompiler, MySQLIdentifierPreparer) +from .base import ( + MySQLDialect, + MySQLExecutionContext, + MySQLCompiler, + MySQLIdentifierPreparer, +) from .base import TEXT from ... import sql from ... import util @@ -54,10 +58,9 @@ import re class MySQLExecutionContext_mysqldb(MySQLExecutionContext): - @property def rowcount(self): - if hasattr(self, '_rowcount'): + if hasattr(self, "_rowcount"): return self._rowcount else: return self.cursor.rowcount @@ -72,14 +75,14 @@ class MySQLIdentifierPreparer_mysqldb(MySQLIdentifierPreparer): class MySQLDialect_mysqldb(MySQLDialect): - driver = 'mysqldb' + driver = "mysqldb" supports_unicode_statements = True supports_sane_rowcount = True supports_sane_multi_rowcount = True supports_native_decimal = True - default_paramstyle = 'format' + default_paramstyle = "format" execution_ctx_cls = MySQLExecutionContext_mysqldb statement_compiler = MySQLCompiler_mysqldb preparer = MySQLIdentifierPreparer_mysqldb @@ -87,24 +90,23 @@ class MySQLDialect_mysqldb(MySQLDialect): def __init__(self, server_side_cursors=False, **kwargs): super(MySQLDialect_mysqldb, self).__init__(**kwargs) self.server_side_cursors = server_side_cursors - self._mysql_dbapi_version = self._parse_dbapi_version( - self.dbapi.__version__) if self.dbapi is not None \ - and hasattr(self.dbapi, '__version__') else (0, 0, 0) + self._mysql_dbapi_version = ( + self._parse_dbapi_version(self.dbapi.__version__) + if self.dbapi is not None and hasattr(self.dbapi, "__version__") + else (0, 0, 0) + ) def _parse_dbapi_version(self, version): - m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version) + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) if m: - return tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) else: return (0, 0, 0) @util.langhelpers.memoized_property def supports_server_side_cursors(self): try: - cursors = __import__('MySQLdb.cursors').cursors + cursors = __import__("MySQLdb.cursors").cursors self._sscursor = cursors.SSCursor return True except (ImportError, AttributeError): @@ -112,7 +114,7 @@ class MySQLDialect_mysqldb(MySQLDialect): @classmethod def dbapi(cls): - return __import__('MySQLdb') + return __import__("MySQLdb") def do_ping(self, dbapi_connection): try: @@ -135,67 +137,74 @@ class MySQLDialect_mysqldb(MySQLDialect): # https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8 # specific issue w/ the utf8mb4_bin collation and unicode returns - has_utf8mb4_bin = self.server_version_info > (5, ) and \ - connection.scalar( - "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" - % ( - self.identifier_preparer.quote("Charset"), - self.identifier_preparer.quote("Collation") - )) + has_utf8mb4_bin = self.server_version_info > ( + 5, + ) and connection.scalar( + "show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'" + % ( + self.identifier_preparer.quote("Charset"), + self.identifier_preparer.quote("Collation"), + ) + ) if has_utf8mb4_bin: additional_tests = [ - sql.collate(sql.cast( - sql.literal_column( - "'test collated returns'"), - TEXT(charset='utf8mb4')), "utf8mb4_bin") + sql.collate( + sql.cast( + sql.literal_column("'test collated returns'"), + TEXT(charset="utf8mb4"), + ), + "utf8mb4_bin", + ) ] else: additional_tests = [] return super(MySQLDialect_mysqldb, self)._check_unicode_returns( - connection, additional_tests) + connection, additional_tests + ) def create_connect_args(self, url): - opts = url.translate_connect_args(database='db', username='user', - password='passwd') + opts = url.translate_connect_args( + database="db", username="user", password="passwd" + ) opts.update(url.query) - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'connect_timeout', int) - util.coerce_kw_type(opts, 'read_timeout', int) - util.coerce_kw_type(opts, 'write_timeout', int) - util.coerce_kw_type(opts, 'client_flag', int) - util.coerce_kw_type(opts, 'local_infile', int) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "connect_timeout", int) + util.coerce_kw_type(opts, "read_timeout", int) + util.coerce_kw_type(opts, "write_timeout", int) + util.coerce_kw_type(opts, "client_flag", int) + util.coerce_kw_type(opts, "local_infile", int) # Note: using either of the below will cause all strings to be # returned as Unicode, both in raw SQL operations and with column # types like String and MSString. - util.coerce_kw_type(opts, 'use_unicode', bool) - util.coerce_kw_type(opts, 'charset', str) + util.coerce_kw_type(opts, "use_unicode", bool) + util.coerce_kw_type(opts, "charset", str) # Rich values 'cursorclass' and 'conv' are not supported via # query string. ssl = {} - keys = ['ssl_ca', 'ssl_key', 'ssl_cert', 'ssl_capath', 'ssl_cipher'] + keys = ["ssl_ca", "ssl_key", "ssl_cert", "ssl_capath", "ssl_cipher"] for key in keys: if key in opts: ssl[key[4:]] = opts[key] util.coerce_kw_type(ssl, key[4:], str) del opts[key] if ssl: - opts['ssl'] = ssl + opts["ssl"] = ssl # FOUND_ROWS must be set in CLIENT_FLAGS to enable # supports_sane_rowcount. - client_flag = opts.get('client_flag', 0) + client_flag = opts.get("client_flag", 0) if self.dbapi is not None: try: CLIENT_FLAGS = __import__( - self.dbapi.__name__ + '.constants.CLIENT' + self.dbapi.__name__ + ".constants.CLIENT" ).constants.CLIENT client_flag |= CLIENT_FLAGS.FOUND_ROWS except (AttributeError, ImportError): self.supports_sane_rowcount = False - opts['client_flag'] = client_flag + opts["client_flag"] = client_flag return [[], opts] def _extract_error_code(self, exception): @@ -213,22 +222,30 @@ class MySQLDialect_mysqldb(MySQLDialect): "No 'character_set_name' can be detected with " "this MySQL-Python version; " "please upgrade to a recent version of MySQL-Python. " - "Assuming latin1.") - return 'latin1' + "Assuming latin1." + ) + return "latin1" else: return cset_name() - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ', - 'AUTOCOMMIT']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + "AUTOCOMMIT", + ] + ) def _set_isolation_level(self, connection, level): - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit(True) else: connection.autocommit(False) - super(MySQLDialect_mysqldb, self)._set_isolation_level(connection, - level) + super(MySQLDialect_mysqldb, self)._set_isolation_level( + connection, level + ) dialect = MySQLDialect_mysqldb diff --git a/lib/sqlalchemy/dialects/mysql/oursql.py b/lib/sqlalchemy/dialects/mysql/oursql.py index 67dbb7cf2..8ba353a31 100644 --- a/lib/sqlalchemy/dialects/mysql/oursql.py +++ b/lib/sqlalchemy/dialects/mysql/oursql.py @@ -24,7 +24,7 @@ handling. import re -from .base import (BIT, MySQLDialect, MySQLExecutionContext) +from .base import BIT, MySQLDialect, MySQLExecutionContext from ... import types as sqltypes, util @@ -36,14 +36,13 @@ class _oursqlBIT(BIT): class MySQLExecutionContext_oursql(MySQLExecutionContext): - @property def plain_query(self): - return self.execution_options.get('_oursql_plain_query', False) + return self.execution_options.get("_oursql_plain_query", False) class MySQLDialect_oursql(MySQLDialect): - driver = 'oursql' + driver = "oursql" if util.py2k: supports_unicode_binds = True @@ -56,16 +55,12 @@ class MySQLDialect_oursql(MySQLDialect): execution_ctx_cls = MySQLExecutionContext_oursql colspecs = util.update_copy( - MySQLDialect.colspecs, - { - sqltypes.Time: sqltypes.Time, - BIT: _oursqlBIT, - } + MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _oursqlBIT} ) @classmethod def dbapi(cls): - return __import__('oursql') + return __import__("oursql") def do_execute(self, cursor, statement, parameters, context=None): """Provide an implementation of @@ -77,7 +72,7 @@ class MySQLDialect_oursql(MySQLDialect): cursor.execute(statement, parameters) def do_begin(self, connection): - connection.cursor().execute('BEGIN', plain_query=True) + connection.cursor().execute("BEGIN", plain_query=True) def _xa_query(self, connection, query, xid): if util.py2k: @@ -85,10 +80,12 @@ class MySQLDialect_oursql(MySQLDialect): else: charset = self._connection_charset arg = connection.connection._escape_string( - xid.encode(charset)).decode(charset) + xid.encode(charset) + ).decode(charset) arg = "'%s'" % arg - connection.execution_options( - _oursql_plain_query=True).execute(query % arg) + connection.execution_options(_oursql_plain_query=True).execute( + query % arg + ) # Because mysql is bad, these methods have to be # reimplemented to use _PlainQuery. Basically, some queries @@ -96,23 +93,25 @@ class MySQLDialect_oursql(MySQLDialect): # the parameterized query API, or refuse to be parameterized # in the first place. def do_begin_twophase(self, connection, xid): - self._xa_query(connection, 'XA BEGIN %s', xid) + self._xa_query(connection, "XA BEGIN %s", xid) def do_prepare_twophase(self, connection, xid): - self._xa_query(connection, 'XA END %s', xid) - self._xa_query(connection, 'XA PREPARE %s', xid) + self._xa_query(connection, "XA END %s", xid) + self._xa_query(connection, "XA PREPARE %s", xid) - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: - self._xa_query(connection, 'XA END %s', xid) - self._xa_query(connection, 'XA ROLLBACK %s', xid) + self._xa_query(connection, "XA END %s", xid) + self._xa_query(connection, "XA ROLLBACK %s", xid) - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_prepare_twophase(connection, xid) - self._xa_query(connection, 'XA COMMIT %s', xid) + self._xa_query(connection, "XA COMMIT %s", xid) # Q: why didn't we need all these "plain_query" overrides earlier ? # am i on a newer/older version of OurSQL ? @@ -121,7 +120,7 @@ class MySQLDialect_oursql(MySQLDialect): self, connection.connect().execution_options(_oursql_plain_query=True), table_name, - schema + schema, ) def get_table_options(self, connection, table_name, schema=None, **kw): @@ -154,7 +153,7 @@ class MySQLDialect_oursql(MySQLDialect): return MySQLDialect.get_table_names( self, connection.connect().execution_options(_oursql_plain_query=True), - schema + schema, ) def get_schema_names(self, connection, **kw): @@ -166,57 +165,69 @@ class MySQLDialect_oursql(MySQLDialect): def initialize(self, connection): return MySQLDialect.initialize( - self, - connection.execution_options(_oursql_plain_query=True) + self, connection.execution_options(_oursql_plain_query=True) ) - def _show_create_table(self, connection, table, charset=None, - full_name=None): + def _show_create_table( + self, connection, table, charset=None, full_name=None + ): return MySQLDialect._show_create_table( self, - connection.contextual_connect(close_with_result=True). - execution_options(_oursql_plain_query=True), - table, charset, full_name + connection.contextual_connect( + close_with_result=True + ).execution_options(_oursql_plain_query=True), + table, + charset, + full_name, ) def is_disconnect(self, e, connection, cursor): if isinstance(e, self.dbapi.ProgrammingError): - return e.errno is None and 'cursor' not in e.args[1] \ - and e.args[1].endswith('closed') + return ( + e.errno is None + and "cursor" not in e.args[1] + and e.args[1].endswith("closed") + ) else: return e.errno in (2006, 2013, 2014, 2045, 2055) def create_connect_args(self, url): - opts = url.translate_connect_args(database='db', username='user', - password='passwd') + opts = url.translate_connect_args( + database="db", username="user", password="passwd" + ) opts.update(url.query) - util.coerce_kw_type(opts, 'port', int) - util.coerce_kw_type(opts, 'compress', bool) - util.coerce_kw_type(opts, 'autoping', bool) - util.coerce_kw_type(opts, 'raise_on_warnings', bool) + util.coerce_kw_type(opts, "port", int) + util.coerce_kw_type(opts, "compress", bool) + util.coerce_kw_type(opts, "autoping", bool) + util.coerce_kw_type(opts, "raise_on_warnings", bool) - util.coerce_kw_type(opts, 'default_charset', bool) - if opts.pop('default_charset', False): - opts['charset'] = None + util.coerce_kw_type(opts, "default_charset", bool) + if opts.pop("default_charset", False): + opts["charset"] = None else: - util.coerce_kw_type(opts, 'charset', str) - opts['use_unicode'] = opts.get('use_unicode', True) - util.coerce_kw_type(opts, 'use_unicode', bool) + util.coerce_kw_type(opts, "charset", str) + opts["use_unicode"] = opts.get("use_unicode", True) + util.coerce_kw_type(opts, "use_unicode", bool) # FOUND_ROWS must be set in CLIENT_FLAGS to enable # supports_sane_rowcount. - opts.setdefault('found_rows', True) + opts.setdefault("found_rows", True) ssl = {} - for key in ['ssl_ca', 'ssl_key', 'ssl_cert', - 'ssl_capath', 'ssl_cipher']: + for key in [ + "ssl_ca", + "ssl_key", + "ssl_cert", + "ssl_capath", + "ssl_cipher", + ]: if key in opts: ssl[key[4:]] = opts[key] util.coerce_kw_type(ssl, key[4:], str) del opts[key] if ssl: - opts['ssl'] = ssl + opts["ssl"] = ssl return [[], opts] diff --git a/lib/sqlalchemy/dialects/mysql/pymysql.py b/lib/sqlalchemy/dialects/mysql/pymysql.py index 5f176cef2..94dbfff06 100644 --- a/lib/sqlalchemy/dialects/mysql/pymysql.py +++ b/lib/sqlalchemy/dialects/mysql/pymysql.py @@ -34,7 +34,7 @@ from ...util import langhelpers, py3k class MySQLDialect_pymysql(MySQLDialect_mysqldb): - driver = 'pymysql' + driver = "pymysql" description_encoding = None @@ -51,7 +51,7 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): @langhelpers.memoized_property def supports_server_side_cursors(self): try: - cursors = __import__('pymysql.cursors').cursors + cursors = __import__("pymysql.cursors").cursors self._sscursor = cursors.SSCursor return True except (ImportError, AttributeError): @@ -59,10 +59,12 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): @classmethod def dbapi(cls): - return __import__('pymysql') + return __import__("pymysql") def is_disconnect(self, e, connection, cursor): - if super(MySQLDialect_pymysql, self).is_disconnect(e, connection, cursor): + if super(MySQLDialect_pymysql, self).is_disconnect( + e, connection, cursor + ): return True elif isinstance(e, self.dbapi.Error): return "Already closed" in str(e) @@ -70,9 +72,11 @@ class MySQLDialect_pymysql(MySQLDialect_mysqldb): return False if py3k: + def _extract_error_code(self, exception): if isinstance(exception.args[0], Exception): exception = exception.args[0] return exception.args[0] + dialect = MySQLDialect_pymysql diff --git a/lib/sqlalchemy/dialects/mysql/pyodbc.py b/lib/sqlalchemy/dialects/mysql/pyodbc.py index 718754651..91512857e 100644 --- a/lib/sqlalchemy/dialects/mysql/pyodbc.py +++ b/lib/sqlalchemy/dialects/mysql/pyodbc.py @@ -29,7 +29,6 @@ import re class MySQLExecutionContext_pyodbc(MySQLExecutionContext): - def get_lastrowid(self): cursor = self.create_cursor() cursor.execute("SELECT LAST_INSERT_ID()") @@ -46,7 +45,7 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): def __init__(self, **kw): # deal with http://code.google.com/p/pyodbc/issues/detail?id=25 - kw.setdefault('convert_unicode', True) + kw.setdefault("convert_unicode", True) super(MySQLDialect_pyodbc, self).__init__(**kw) def _detect_charset(self, connection): @@ -60,13 +59,15 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): # this can prefer the driver value. rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") opts = {row[0]: row[1] for row in self._compat_fetchall(rs)} - for key in ('character_set_connection', 'character_set'): + for key in ("character_set_connection", "character_set"): if opts.get(key, None): return opts[key] - util.warn("Could not detect the connection character set. " - "Assuming latin1.") - return 'latin1' + util.warn( + "Could not detect the connection character set. " + "Assuming latin1." + ) + return "latin1" def _extract_error_code(self, exception): m = re.compile(r"\((\d+)\)").search(str(exception.args)) @@ -76,4 +77,5 @@ class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect): else: return None + dialect = MySQLDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/mysql/reflection.py b/lib/sqlalchemy/dialects/mysql/reflection.py index e88bc3f42..d0513eb4d 100644 --- a/lib/sqlalchemy/dialects/mysql/reflection.py +++ b/lib/sqlalchemy/dialects/mysql/reflection.py @@ -36,16 +36,16 @@ class MySQLTableDefinitionParser(object): def parse(self, show_create, charset): state = ReflectedState() state.charset = charset - for line in re.split(r'\r?\n', show_create): - if line.startswith(' ' + self.preparer.initial_quote): + for line in re.split(r"\r?\n", show_create): + if line.startswith(" " + self.preparer.initial_quote): self._parse_column(line, state) # a regular table options line - elif line.startswith(') '): + elif line.startswith(") "): self._parse_table_options(line, state) # an ANSI-mode table options line - elif line == ')': + elif line == ")": pass - elif line.startswith('CREATE '): + elif line.startswith("CREATE "): self._parse_table_name(line, state) # Not present in real reflection, but may be if # loading from a file. @@ -55,11 +55,11 @@ class MySQLTableDefinitionParser(object): type_, spec = self._parse_constraints(line) if type_ is None: util.warn("Unknown schema content: %r" % line) - elif type_ == 'key': + elif type_ == "key": state.keys.append(spec) - elif type_ == 'fk_constraint': + elif type_ == "fk_constraint": state.fk_constraints.append(spec) - elif type_ == 'ck_constraint': + elif type_ == "ck_constraint": state.ck_constraints.append(spec) else: pass @@ -78,39 +78,39 @@ class MySQLTableDefinitionParser(object): # convert columns into name, length pairs # NOTE: we may want to consider SHOW INDEX as the # format of indexes in MySQL becomes more complex - spec['columns'] = self._parse_keyexprs(spec['columns']) - if spec['version_sql']: - m2 = self._re_key_version_sql.match(spec['version_sql']) - if m2 and m2.groupdict()['parser']: - spec['parser'] = m2.groupdict()['parser'] - if spec['parser']: - spec['parser'] = self.preparer.unformat_identifiers( - spec['parser'])[0] - return 'key', spec + spec["columns"] = self._parse_keyexprs(spec["columns"]) + if spec["version_sql"]: + m2 = self._re_key_version_sql.match(spec["version_sql"]) + if m2 and m2.groupdict()["parser"]: + spec["parser"] = m2.groupdict()["parser"] + if spec["parser"]: + spec["parser"] = self.preparer.unformat_identifiers( + spec["parser"] + )[0] + return "key", spec # FOREIGN KEY CONSTRAINT m = self._re_fk_constraint.match(line) if m: spec = m.groupdict() - spec['table'] = \ - self.preparer.unformat_identifiers(spec['table']) - spec['local'] = [c[0] - for c in self._parse_keyexprs(spec['local'])] - spec['foreign'] = [c[0] - for c in self._parse_keyexprs(spec['foreign'])] - return 'fk_constraint', spec + spec["table"] = self.preparer.unformat_identifiers(spec["table"]) + spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])] + spec["foreign"] = [ + c[0] for c in self._parse_keyexprs(spec["foreign"]) + ] + return "fk_constraint", spec # CHECK constraint m = self._re_ck_constraint.match(line) if m: spec = m.groupdict() - return 'ck_constraint', spec + return "ck_constraint", spec # PARTITION and SUBPARTITION m = self._re_partition.match(line) if m: # Punt! - return 'partition', line + return "partition", line # No match. return (None, line) @@ -124,7 +124,7 @@ class MySQLTableDefinitionParser(object): regex, cleanup = self._pr_name m = regex.match(line) if m: - state.table_name = cleanup(m.group('name')) + state.table_name = cleanup(m.group("name")) def _parse_table_options(self, line, state): """Build a dictionary of all reflected table-level options. @@ -134,7 +134,7 @@ class MySQLTableDefinitionParser(object): options = {} - if not line or line == ')': + if not line or line == ")": pass else: @@ -143,17 +143,17 @@ class MySQLTableDefinitionParser(object): m = regex.search(rest_of_line) if not m: continue - directive, value = m.group('directive'), m.group('val') + directive, value = m.group("directive"), m.group("val") if cleanup: value = cleanup(value) options[directive.lower()] = value - rest_of_line = regex.sub('', rest_of_line) + rest_of_line = regex.sub("", rest_of_line) - for nope in ('auto_increment', 'data directory', 'index directory'): + for nope in ("auto_increment", "data directory", "index directory"): options.pop(nope, None) for opt, val in options.items(): - state.table_options['%s_%s' % (self.dialect.name, opt)] = val + state.table_options["%s_%s" % (self.dialect.name, opt)] = val def _parse_column(self, line, state): """Extract column details. @@ -167,29 +167,30 @@ class MySQLTableDefinitionParser(object): m = self._re_column.match(line) if m: spec = m.groupdict() - spec['full'] = True + spec["full"] = True else: m = self._re_column_loose.match(line) if m: spec = m.groupdict() - spec['full'] = False + spec["full"] = False if not spec: util.warn("Unknown column definition %r" % line) return - if not spec['full']: + if not spec["full"]: util.warn("Incomplete reflection of column definition %r" % line) - name, type_, args = spec['name'], spec['coltype'], spec['arg'] + name, type_, args = spec["name"], spec["coltype"], spec["arg"] try: col_type = self.dialect.ischema_names[type_] except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (type_, name)) + util.warn( + "Did not recognize type '%s' of column '%s'" % (type_, name) + ) col_type = sqltypes.NullType # Column type positional arguments eg. varchar(32) - if args is None or args == '': + if args is None or args == "": type_args = [] elif args[0] == "'" and args[-1] == "'": type_args = self._re_csv_str.findall(args) @@ -201,50 +202,51 @@ class MySQLTableDefinitionParser(object): if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)): if type_args: - type_kw['fsp'] = type_args.pop(0) + type_kw["fsp"] = type_args.pop(0) - for kw in ('unsigned', 'zerofill'): + for kw in ("unsigned", "zerofill"): if spec.get(kw, False): type_kw[kw] = True - for kw in ('charset', 'collate'): + for kw in ("charset", "collate"): if spec.get(kw, False): type_kw[kw] = spec[kw] if issubclass(col_type, _EnumeratedValues): type_args = _EnumeratedValues._strip_values(type_args) - if issubclass(col_type, SET) and '' in type_args: - type_kw['retrieve_as_bitwise'] = True + if issubclass(col_type, SET) and "" in type_args: + type_kw["retrieve_as_bitwise"] = True type_instance = col_type(*type_args, **type_kw) col_kw = {} # NOT NULL - col_kw['nullable'] = True + col_kw["nullable"] = True # this can be "NULL" in the case of TIMESTAMP - if spec.get('notnull', False) == 'NOT NULL': - col_kw['nullable'] = False + if spec.get("notnull", False) == "NOT NULL": + col_kw["nullable"] = False # AUTO_INCREMENT - if spec.get('autoincr', False): - col_kw['autoincrement'] = True + if spec.get("autoincr", False): + col_kw["autoincrement"] = True elif issubclass(col_type, sqltypes.Integer): - col_kw['autoincrement'] = False + col_kw["autoincrement"] = False # DEFAULT - default = spec.get('default', None) + default = spec.get("default", None) - if default == 'NULL': + if default == "NULL": # eliminates the need to deal with this later. default = None - comment = spec.get('comment', None) + comment = spec.get("comment", None) if comment is not None: comment = comment.replace("\\\\", "\\").replace("''", "'") - col_d = dict(name=name, type=type_instance, default=default, - comment=comment) + col_d = dict( + name=name, type=type_instance, default=default, comment=comment + ) col_d.update(col_kw) state.columns.append(col_d) @@ -262,36 +264,44 @@ class MySQLTableDefinitionParser(object): buffer = [] for row in columns: - (name, col_type, nullable, default, extra) = \ - [row[i] for i in (0, 1, 2, 4, 5)] + (name, col_type, nullable, default, extra) = [ + row[i] for i in (0, 1, 2, 4, 5) + ] - line = [' '] + line = [" "] line.append(self.preparer.quote_identifier(name)) line.append(col_type) if not nullable: - line.append('NOT NULL') + line.append("NOT NULL") if default: - if 'auto_increment' in default: + if "auto_increment" in default: pass - elif (col_type.startswith('timestamp') and - default.startswith('C')): - line.append('DEFAULT') + elif col_type.startswith("timestamp") and default.startswith( + "C" + ): + line.append("DEFAULT") line.append(default) - elif default == 'NULL': - line.append('DEFAULT') + elif default == "NULL": + line.append("DEFAULT") line.append(default) else: - line.append('DEFAULT') + line.append("DEFAULT") line.append("'%s'" % default.replace("'", "''")) if extra: line.append(extra) - buffer.append(' '.join(line)) - - return ''.join([('CREATE TABLE %s (\n' % - self.preparer.quote_identifier(table_name)), - ',\n'.join(buffer), - '\n) ']) + buffer.append(" ".join(line)) + + return "".join( + [ + ( + "CREATE TABLE %s (\n" + % self.preparer.quote_identifier(table_name) + ), + ",\n".join(buffer), + "\n) ", + ] + ) def _parse_keyexprs(self, identifiers): """Unpack '"col"(2),"col" ASC'-ish strings into components.""" @@ -306,29 +316,39 @@ class MySQLTableDefinitionParser(object): _final = self.preparer.final_quote - quotes = dict(zip(('iq', 'fq', 'esc_fq'), - [re.escape(s) for s in - (self.preparer.initial_quote, - _final, - self.preparer._escape_identifier(_final))])) + quotes = dict( + zip( + ("iq", "fq", "esc_fq"), + [ + re.escape(s) + for s in ( + self.preparer.initial_quote, + _final, + self.preparer._escape_identifier(_final), + ) + ], + ) + ) self._pr_name = _pr_compile( - r'^CREATE (?:\w+ +)?TABLE +' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($' % quotes, - self.preparer._unescape_identifier) + r"^CREATE (?:\w+ +)?TABLE +" + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes, + self.preparer._unescape_identifier, + ) # `col`,`col2`(32),`col3`(15) DESC # self._re_keyexprs = _re_compile( - r'(?:' - r'(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)' - r'(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+' % quotes) + r"(?:" + r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)" + r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes + ) # 'foo' or 'foo','bar' or 'fo,o','ba''a''r' - self._re_csv_str = _re_compile(r'\x27(?:\x27\x27|[^\x27])*\x27') + self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27") # 123 or 123,456 - self._re_csv_int = _re_compile(r'\d+') + self._re_csv_int = _re_compile(r"\d+") # `colname` <type> [type opts] # (NOT NULL | NULL) @@ -356,43 +376,39 @@ class MySQLTableDefinitionParser(object): r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?" r"(?: +STORAGE +(?P<storage>\w+))?" r"(?: +(?P<extra>.*))?" - r",?$" - % quotes + r",?$" % quotes ) # Fallback, try to parse as little as possible self._re_column_loose = _re_compile( - r' ' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' - r'(?P<coltype>\w+)' - r'(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?' - r'.*?(?P<notnull>(?:NOT )NULL)?' - % quotes + r" " + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"(?P<coltype>\w+)" + r"(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?" + r".*?(?P<notnull>(?:NOT )NULL)?" % quotes ) # (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))? # (`col` (ASC|DESC)?, `col` (ASC|DESC)?) # KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */ self._re_key = _re_compile( - r' ' - r'(?:(?P<type>\S+) )?KEY' - r'(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?' - r'(?: +USING +(?P<using_pre>\S+))?' - r' +\((?P<columns>.+?)\)' - r'(?: +USING +(?P<using_post>\S+))?' - r'(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?' - r'(?: +WITH PARSER +(?P<parser>\S+))?' - r'(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?' - r'(?: +/\*(?P<version_sql>.+)\*/ +)?' - r',?$' - % quotes + r" " + r"(?:(?P<type>\S+) )?KEY" + r"(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?" + r"(?: +USING +(?P<using_pre>\S+))?" + r" +\((?P<columns>.+?)\)" + r"(?: +USING +(?P<using_post>\S+))?" + r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?" + r"(?: +WITH PARSER +(?P<parser>\S+))?" + r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?" + r"(?: +/\*(?P<version_sql>.+)\*/ +)?" + r",?$" % quotes ) # https://forums.mysql.com/read.php?20,567102,567111#msg-567111 # It means if the MySQL version >= \d+, execute what's in the comment self._re_key_version_sql = _re_compile( - r'\!\d+ ' - r'(?: *WITH PARSER +(?P<parser>\S+) *)?' + r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?" ) # CONSTRAINT `name` FOREIGN KEY (`local_col`) @@ -402,20 +418,19 @@ class MySQLTableDefinitionParser(object): # # unique constraints come back as KEYs kw = quotes.copy() - kw['on'] = 'RESTRICT|CASCADE|SET NULL|NOACTION' + kw["on"] = "RESTRICT|CASCADE|SET NULL|NOACTION" self._re_fk_constraint = _re_compile( - r' ' - r'CONSTRAINT +' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' - r'FOREIGN KEY +' - r'\((?P<local>[^\)]+?)\) REFERENCES +' - r'(?P<table>%(iq)s[^%(fq)s]+%(fq)s' - r'(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +' - r'\((?P<foreign>[^\)]+?)\)' - r'(?: +(?P<match>MATCH \w+))?' - r'(?: +ON DELETE (?P<ondelete>%(on)s))?' - r'(?: +ON UPDATE (?P<onupdate>%(on)s))?' - % kw + r" " + r"CONSTRAINT +" + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"FOREIGN KEY +" + r"\((?P<local>[^\)]+?)\) REFERENCES +" + r"(?P<table>%(iq)s[^%(fq)s]+%(fq)s" + r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +" + r"\((?P<foreign>[^\)]+?)\)" + r"(?: +(?P<match>MATCH \w+))?" + r"(?: +ON DELETE (?P<ondelete>%(on)s))?" + r"(?: +ON UPDATE (?P<onupdate>%(on)s))?" % kw ) # CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)' @@ -423,18 +438,17 @@ class MySQLTableDefinitionParser(object): # is returned on a line by itself, so to match without worrying # about parenthesis in the expresion we go to the end of the line self._re_ck_constraint = _re_compile( - r' ' - r'CONSTRAINT +' - r'%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +' - r'CHECK +' - r'\((?P<sqltext>.+)\),?' - % kw + r" " + r"CONSTRAINT +" + r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +" + r"CHECK +" + r"\((?P<sqltext>.+)\),?" % kw ) # PARTITION # # punt! - self._re_partition = _re_compile(r'(?:.*)(?:SUB)?PARTITION(?:.*)') + self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)") # Table-level options (COLLATE, ENGINE, etc.) # Do the string options first, since they have quoted @@ -442,44 +456,68 @@ class MySQLTableDefinitionParser(object): for option in _options_of_type_string: self._add_option_string(option) - for option in ('ENGINE', 'TYPE', 'AUTO_INCREMENT', - 'AVG_ROW_LENGTH', 'CHARACTER SET', - 'DEFAULT CHARSET', 'CHECKSUM', - 'COLLATE', 'DELAY_KEY_WRITE', 'INSERT_METHOD', - 'MAX_ROWS', 'MIN_ROWS', 'PACK_KEYS', 'ROW_FORMAT', - 'KEY_BLOCK_SIZE'): + for option in ( + "ENGINE", + "TYPE", + "AUTO_INCREMENT", + "AVG_ROW_LENGTH", + "CHARACTER SET", + "DEFAULT CHARSET", + "CHECKSUM", + "COLLATE", + "DELAY_KEY_WRITE", + "INSERT_METHOD", + "MAX_ROWS", + "MIN_ROWS", + "PACK_KEYS", + "ROW_FORMAT", + "KEY_BLOCK_SIZE", + ): self._add_option_word(option) - self._add_option_regex('UNION', r'\([^\)]+\)') - self._add_option_regex('TABLESPACE', r'.*? STORAGE DISK') + self._add_option_regex("UNION", r"\([^\)]+\)") + self._add_option_regex("TABLESPACE", r".*? STORAGE DISK") self._add_option_regex( - 'RAID_TYPE', - r'\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+') + "RAID_TYPE", + r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+", + ) - _optional_equals = r'(?:\s*(?:=\s*)|\s+)' + _optional_equals = r"(?:\s*(?:=\s*)|\s+)" def _add_option_string(self, directive): - regex = (r'(?P<directive>%s)%s' - r"'(?P<val>(?:[^']|'')*?)'(?!')" % - (re.escape(directive), self._optional_equals)) - self._pr_options.append(_pr_compile( - regex, lambda v: v.replace("\\\\", "\\").replace("''", "'") - )) + regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % ( + re.escape(directive), + self._optional_equals, + ) + self._pr_options.append( + _pr_compile( + regex, lambda v: v.replace("\\\\", "\\").replace("''", "'") + ) + ) def _add_option_word(self, directive): - regex = (r'(?P<directive>%s)%s' - r'(?P<val>\w+)' % - (re.escape(directive), self._optional_equals)) + regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % ( + re.escape(directive), + self._optional_equals, + ) self._pr_options.append(_pr_compile(regex)) def _add_option_regex(self, directive, regex): - regex = (r'(?P<directive>%s)%s' - r'(?P<val>%s)' % - (re.escape(directive), self._optional_equals, regex)) + regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % ( + re.escape(directive), + self._optional_equals, + regex, + ) self._pr_options.append(_pr_compile(regex)) -_options_of_type_string = ('COMMENT', 'DATA DIRECTORY', 'INDEX DIRECTORY', - 'PASSWORD', 'CONNECTION') + +_options_of_type_string = ( + "COMMENT", + "DATA DIRECTORY", + "INDEX DIRECTORY", + "PASSWORD", + "CONNECTION", +) def _pr_compile(regex, cleanup=None): diff --git a/lib/sqlalchemy/dialects/mysql/types.py b/lib/sqlalchemy/dialects/mysql/types.py index cb09a0841..ad97a9bbe 100644 --- a/lib/sqlalchemy/dialects/mysql/types.py +++ b/lib/sqlalchemy/dialects/mysql/types.py @@ -24,28 +24,30 @@ class _NumericType(object): super(_NumericType, self).__init__(**kw) def __repr__(self): - return util.generic_repr(self, - to_inspect=[_NumericType, sqltypes.Numeric]) + return util.generic_repr( + self, to_inspect=[_NumericType, sqltypes.Numeric] + ) class _FloatType(_NumericType, sqltypes.Float): def __init__(self, precision=None, scale=None, asdecimal=True, **kw): - if isinstance(self, (REAL, DOUBLE)) and \ - ( - (precision is None and scale is not None) or - (precision is not None and scale is None) + if isinstance(self, (REAL, DOUBLE)) and ( + (precision is None and scale is not None) + or (precision is not None and scale is None) ): raise exc.ArgumentError( "You must specify both precision and scale or omit " - "both altogether.") + "both altogether." + ) super(_FloatType, self).__init__( - precision=precision, asdecimal=asdecimal, **kw) + precision=precision, asdecimal=asdecimal, **kw + ) self.scale = scale def __repr__(self): - return util.generic_repr(self, to_inspect=[_FloatType, - _NumericType, - sqltypes.Float]) + return util.generic_repr( + self, to_inspect=[_FloatType, _NumericType, sqltypes.Float] + ) class _IntegerType(_NumericType, sqltypes.Integer): @@ -54,21 +56,28 @@ class _IntegerType(_NumericType, sqltypes.Integer): super(_IntegerType, self).__init__(**kw) def __repr__(self): - return util.generic_repr(self, to_inspect=[_IntegerType, - _NumericType, - sqltypes.Integer]) + return util.generic_repr( + self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer] + ) class _StringType(sqltypes.String): """Base for MySQL string types.""" - def __init__(self, charset=None, collation=None, - ascii=False, binary=False, unicode=False, - national=False, **kw): + def __init__( + self, + charset=None, + collation=None, + ascii=False, + binary=False, + unicode=False, + national=False, + **kw + ): self.charset = charset # allow collate= or collation= - kw.setdefault('collation', kw.pop('collate', collation)) + kw.setdefault("collation", kw.pop("collate", collation)) self.ascii = ascii self.unicode = unicode @@ -77,8 +86,9 @@ class _StringType(sqltypes.String): super(_StringType, self).__init__(**kw) def __repr__(self): - return util.generic_repr(self, - to_inspect=[_StringType, sqltypes.String]) + return util.generic_repr( + self, to_inspect=[_StringType, sqltypes.String] + ) class _MatchType(sqltypes.Float, sqltypes.MatchType): @@ -88,11 +98,10 @@ class _MatchType(sqltypes.Float, sqltypes.MatchType): sqltypes.MatchType.__init__(self) - class NUMERIC(_NumericType, sqltypes.NUMERIC): """MySQL NUMERIC type.""" - __visit_name__ = 'NUMERIC' + __visit_name__ = "NUMERIC" def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a NUMERIC. @@ -110,14 +119,15 @@ class NUMERIC(_NumericType, sqltypes.NUMERIC): numeric. """ - super(NUMERIC, self).__init__(precision=precision, - scale=scale, asdecimal=asdecimal, **kw) + super(NUMERIC, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) class DECIMAL(_NumericType, sqltypes.DECIMAL): """MySQL DECIMAL type.""" - __visit_name__ = 'DECIMAL' + __visit_name__ = "DECIMAL" def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DECIMAL. @@ -135,14 +145,15 @@ class DECIMAL(_NumericType, sqltypes.DECIMAL): numeric. """ - super(DECIMAL, self).__init__(precision=precision, scale=scale, - asdecimal=asdecimal, **kw) + super(DECIMAL, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) class DOUBLE(_FloatType): """MySQL DOUBLE type.""" - __visit_name__ = 'DOUBLE' + __visit_name__ = "DOUBLE" def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a DOUBLE. @@ -168,14 +179,15 @@ class DOUBLE(_FloatType): numeric. """ - super(DOUBLE, self).__init__(precision=precision, scale=scale, - asdecimal=asdecimal, **kw) + super(DOUBLE, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) class REAL(_FloatType, sqltypes.REAL): """MySQL REAL type.""" - __visit_name__ = 'REAL' + __visit_name__ = "REAL" def __init__(self, precision=None, scale=None, asdecimal=True, **kw): """Construct a REAL. @@ -201,14 +213,15 @@ class REAL(_FloatType, sqltypes.REAL): numeric. """ - super(REAL, self).__init__(precision=precision, scale=scale, - asdecimal=asdecimal, **kw) + super(REAL, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) class FLOAT(_FloatType, sqltypes.FLOAT): """MySQL FLOAT type.""" - __visit_name__ = 'FLOAT' + __visit_name__ = "FLOAT" def __init__(self, precision=None, scale=None, asdecimal=False, **kw): """Construct a FLOAT. @@ -226,8 +239,9 @@ class FLOAT(_FloatType, sqltypes.FLOAT): numeric. """ - super(FLOAT, self).__init__(precision=precision, scale=scale, - asdecimal=asdecimal, **kw) + super(FLOAT, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal, **kw + ) def bind_processor(self, dialect): return None @@ -236,7 +250,7 @@ class FLOAT(_FloatType, sqltypes.FLOAT): class INTEGER(_IntegerType, sqltypes.INTEGER): """MySQL INTEGER type.""" - __visit_name__ = 'INTEGER' + __visit_name__ = "INTEGER" def __init__(self, display_width=None, **kw): """Construct an INTEGER. @@ -257,7 +271,7 @@ class INTEGER(_IntegerType, sqltypes.INTEGER): class BIGINT(_IntegerType, sqltypes.BIGINT): """MySQL BIGINTEGER type.""" - __visit_name__ = 'BIGINT' + __visit_name__ = "BIGINT" def __init__(self, display_width=None, **kw): """Construct a BIGINTEGER. @@ -278,7 +292,7 @@ class BIGINT(_IntegerType, sqltypes.BIGINT): class MEDIUMINT(_IntegerType): """MySQL MEDIUMINTEGER type.""" - __visit_name__ = 'MEDIUMINT' + __visit_name__ = "MEDIUMINT" def __init__(self, display_width=None, **kw): """Construct a MEDIUMINTEGER @@ -299,7 +313,7 @@ class MEDIUMINT(_IntegerType): class TINYINT(_IntegerType): """MySQL TINYINT type.""" - __visit_name__ = 'TINYINT' + __visit_name__ = "TINYINT" def __init__(self, display_width=None, **kw): """Construct a TINYINT. @@ -320,7 +334,7 @@ class TINYINT(_IntegerType): class SMALLINT(_IntegerType, sqltypes.SMALLINT): """MySQL SMALLINTEGER type.""" - __visit_name__ = 'SMALLINT' + __visit_name__ = "SMALLINT" def __init__(self, display_width=None, **kw): """Construct a SMALLINTEGER. @@ -347,7 +361,7 @@ class BIT(sqltypes.TypeEngine): """ - __visit_name__ = 'BIT' + __visit_name__ = "BIT" def __init__(self, length=None): """Construct a BIT. @@ -374,13 +388,14 @@ class BIT(sqltypes.TypeEngine): v = v << 8 | i return v return value + return process class TIME(sqltypes.TIME): """MySQL TIME type. """ - __visit_name__ = 'TIME' + __visit_name__ = "TIME" def __init__(self, timezone=False, fsp=None): """Construct a MySQL TIME type. @@ -413,12 +428,15 @@ class TIME(sqltypes.TIME): microseconds = value.microseconds seconds = value.seconds minutes = seconds // 60 - return time(minutes // 60, - minutes % 60, - seconds - minutes * 60, - microsecond=microseconds) + return time( + minutes // 60, + minutes % 60, + seconds - minutes * 60, + microsecond=microseconds, + ) else: return None + return process @@ -427,7 +445,7 @@ class TIMESTAMP(sqltypes.TIMESTAMP): """ - __visit_name__ = 'TIMESTAMP' + __visit_name__ = "TIMESTAMP" def __init__(self, timezone=False, fsp=None): """Construct a MySQL TIMESTAMP type. @@ -457,7 +475,7 @@ class DATETIME(sqltypes.DATETIME): """ - __visit_name__ = 'DATETIME' + __visit_name__ = "DATETIME" def __init__(self, timezone=False, fsp=None): """Construct a MySQL DATETIME type. @@ -485,7 +503,7 @@ class DATETIME(sqltypes.DATETIME): class YEAR(sqltypes.TypeEngine): """MySQL YEAR type, for single byte storage of years 1901-2155.""" - __visit_name__ = 'YEAR' + __visit_name__ = "YEAR" def __init__(self, display_width=None): self.display_width = display_width @@ -494,7 +512,7 @@ class YEAR(sqltypes.TypeEngine): class TEXT(_StringType, sqltypes.TEXT): """MySQL TEXT type, for text up to 2^16 characters.""" - __visit_name__ = 'TEXT' + __visit_name__ = "TEXT" def __init__(self, length=None, **kw): """Construct a TEXT. @@ -530,7 +548,7 @@ class TEXT(_StringType, sqltypes.TEXT): class TINYTEXT(_StringType): """MySQL TINYTEXT type, for text up to 2^8 characters.""" - __visit_name__ = 'TINYTEXT' + __visit_name__ = "TINYTEXT" def __init__(self, **kwargs): """Construct a TINYTEXT. @@ -562,7 +580,7 @@ class TINYTEXT(_StringType): class MEDIUMTEXT(_StringType): """MySQL MEDIUMTEXT type, for text up to 2^24 characters.""" - __visit_name__ = 'MEDIUMTEXT' + __visit_name__ = "MEDIUMTEXT" def __init__(self, **kwargs): """Construct a MEDIUMTEXT. @@ -594,7 +612,7 @@ class MEDIUMTEXT(_StringType): class LONGTEXT(_StringType): """MySQL LONGTEXT type, for text up to 2^32 characters.""" - __visit_name__ = 'LONGTEXT' + __visit_name__ = "LONGTEXT" def __init__(self, **kwargs): """Construct a LONGTEXT. @@ -626,7 +644,7 @@ class LONGTEXT(_StringType): class VARCHAR(_StringType, sqltypes.VARCHAR): """MySQL VARCHAR type, for variable-length character data.""" - __visit_name__ = 'VARCHAR' + __visit_name__ = "VARCHAR" def __init__(self, length=None, **kwargs): """Construct a VARCHAR. @@ -658,7 +676,7 @@ class VARCHAR(_StringType, sqltypes.VARCHAR): class CHAR(_StringType, sqltypes.CHAR): """MySQL CHAR type, for fixed-length character data.""" - __visit_name__ = 'CHAR' + __visit_name__ = "CHAR" def __init__(self, length=None, **kwargs): """Construct a CHAR. @@ -690,7 +708,7 @@ class CHAR(_StringType, sqltypes.CHAR): ascii=type_.ascii, binary=type_.binary, unicode=type_.unicode, - national=False # not supported in CAST + national=False, # not supported in CAST ) else: return CHAR(length=type_.length) @@ -703,7 +721,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): character set. """ - __visit_name__ = 'NVARCHAR' + __visit_name__ = "NVARCHAR" def __init__(self, length=None, **kwargs): """Construct an NVARCHAR. @@ -718,7 +736,7 @@ class NVARCHAR(_StringType, sqltypes.NVARCHAR): compatible with the national character set. """ - kwargs['national'] = True + kwargs["national"] = True super(NVARCHAR, self).__init__(length=length, **kwargs) @@ -729,7 +747,7 @@ class NCHAR(_StringType, sqltypes.NCHAR): character set. """ - __visit_name__ = 'NCHAR' + __visit_name__ = "NCHAR" def __init__(self, length=None, **kwargs): """Construct an NCHAR. @@ -744,23 +762,23 @@ class NCHAR(_StringType, sqltypes.NCHAR): compatible with the national character set. """ - kwargs['national'] = True + kwargs["national"] = True super(NCHAR, self).__init__(length=length, **kwargs) class TINYBLOB(sqltypes._Binary): """MySQL TINYBLOB type, for binary data up to 2^8 bytes.""" - __visit_name__ = 'TINYBLOB' + __visit_name__ = "TINYBLOB" class MEDIUMBLOB(sqltypes._Binary): """MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes.""" - __visit_name__ = 'MEDIUMBLOB' + __visit_name__ = "MEDIUMBLOB" class LONGBLOB(sqltypes._Binary): """MySQL LONGBLOB type, for binary data up to 2^32 bytes.""" - __visit_name__ = 'LONGBLOB' + __visit_name__ = "LONGBLOB" diff --git a/lib/sqlalchemy/dialects/mysql/zxjdbc.py b/lib/sqlalchemy/dialects/mysql/zxjdbc.py index 4aee2dbb7..d8ee43748 100644 --- a/lib/sqlalchemy/dialects/mysql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/mysql/zxjdbc.py @@ -37,6 +37,7 @@ from .base import BIT, MySQLDialect, MySQLExecutionContext class _ZxJDBCBit(BIT): def result_processor(self, dialect, coltype): """Converts boolean or byte arrays from MySQL Connector/J to longs.""" + def process(value): if value is None: return value @@ -44,9 +45,10 @@ class _ZxJDBCBit(BIT): return int(value) v = 0 for i in value: - v = v << 8 | (i & 0xff) + v = v << 8 | (i & 0xFF) value = v return value + return process @@ -60,17 +62,13 @@ class MySQLExecutionContext_zxjdbc(MySQLExecutionContext): class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): - jdbc_db_name = 'mysql' - jdbc_driver_name = 'com.mysql.jdbc.Driver' + jdbc_db_name = "mysql" + jdbc_driver_name = "com.mysql.jdbc.Driver" execution_ctx_cls = MySQLExecutionContext_zxjdbc colspecs = util.update_copy( - MySQLDialect.colspecs, - { - sqltypes.Time: sqltypes.Time, - BIT: _ZxJDBCBit - } + MySQLDialect.colspecs, {sqltypes.Time: sqltypes.Time, BIT: _ZxJDBCBit} ) def _detect_charset(self, connection): @@ -83,17 +81,19 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): # this can prefer the driver value. rs = connection.execute("SHOW VARIABLES LIKE 'character_set%%'") opts = {row[0]: row[1] for row in self._compat_fetchall(rs)} - for key in ('character_set_connection', 'character_set'): + for key in ("character_set_connection", "character_set"): if opts.get(key, None): return opts[key] - util.warn("Could not detect the connection character set. " - "Assuming latin1.") - return 'latin1' + util.warn( + "Could not detect the connection character set. " + "Assuming latin1." + ) + return "latin1" def _driver_kwargs(self): """return kw arg dict to be sent to connect().""" - return dict(characterEncoding='UTF-8', yearIsDateType='false') + return dict(characterEncoding="UTF-8", yearIsDateType="false") def _extract_error_code(self, exception): # e.g.: DBAPIError: (Error) Table 'test.u2' doesn't exist @@ -106,7 +106,7 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): def _get_server_version_info(self, connection): dbapi_con = connection.connection version = [] - r = re.compile(r'[.\-]') + r = re.compile(r"[.\-]") for n in r.split(dbapi_con.dbversion): try: version.append(int(n)) @@ -114,4 +114,5 @@ class MySQLDialect_zxjdbc(ZxJDBCConnector, MySQLDialect): version.append(n) return tuple(version) + dialect = MySQLDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/oracle/__init__.py b/lib/sqlalchemy/dialects/oracle/__init__.py index e3d9fed2c..1b9007fcc 100644 --- a/lib/sqlalchemy/dialects/oracle/__init__.py +++ b/lib/sqlalchemy/dialects/oracle/__init__.py @@ -7,18 +7,51 @@ from . import base, cx_oracle, zxjdbc # noqa -from .base import \ - VARCHAR, NVARCHAR, CHAR, DATE, NUMBER,\ - BLOB, BFILE, BINARY_FLOAT, BINARY_DOUBLE, CLOB, NCLOB, TIMESTAMP, RAW,\ - FLOAT, DOUBLE_PRECISION, LONG, INTERVAL,\ - VARCHAR2, NVARCHAR2, ROWID +from .base import ( + VARCHAR, + NVARCHAR, + CHAR, + DATE, + NUMBER, + BLOB, + BFILE, + BINARY_FLOAT, + BINARY_DOUBLE, + CLOB, + NCLOB, + TIMESTAMP, + RAW, + FLOAT, + DOUBLE_PRECISION, + LONG, + INTERVAL, + VARCHAR2, + NVARCHAR2, + ROWID, +) base.dialect = dialect = cx_oracle.dialect __all__ = ( - 'VARCHAR', 'NVARCHAR', 'CHAR', 'DATE', 'NUMBER', - 'BLOB', 'BFILE', 'CLOB', 'NCLOB', 'TIMESTAMP', 'RAW', - 'FLOAT', 'DOUBLE_PRECISION', 'BINARY_DOUBLE', 'BINARY_FLOAT', - 'LONG', 'dialect', 'INTERVAL', - 'VARCHAR2', 'NVARCHAR2', 'ROWID' + "VARCHAR", + "NVARCHAR", + "CHAR", + "DATE", + "NUMBER", + "BLOB", + "BFILE", + "CLOB", + "NCLOB", + "TIMESTAMP", + "RAW", + "FLOAT", + "DOUBLE_PRECISION", + "BINARY_DOUBLE", + "BINARY_FLOAT", + "LONG", + "dialect", + "INTERVAL", + "VARCHAR2", + "NVARCHAR2", + "ROWID", ) diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index b5aea4386..944fe21c3 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -353,49 +353,63 @@ from sqlalchemy.sql import compiler, visitors, expression, util as sql_util from sqlalchemy.sql import operators as sql_operators from sqlalchemy.sql.elements import quoted_name from sqlalchemy import types as sqltypes, schema as sa_schema -from sqlalchemy.types import VARCHAR, NVARCHAR, CHAR, \ - BLOB, CLOB, TIMESTAMP, FLOAT, INTEGER +from sqlalchemy.types import ( + VARCHAR, + NVARCHAR, + CHAR, + BLOB, + CLOB, + TIMESTAMP, + FLOAT, + INTEGER, +) from itertools import groupby -RESERVED_WORDS = \ - set('SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN ' - 'DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED ' - 'ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE ' - 'ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE ' - 'BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES ' - 'AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS ' - 'NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER ' - 'CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR ' - 'DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL'.split()) +RESERVED_WORDS = set( + "SHARE RAW DROP BETWEEN FROM DESC OPTION PRIOR LONG THEN " + "DEFAULT ALTER IS INTO MINUS INTEGER NUMBER GRANT IDENTIFIED " + "ALL TO ORDER ON FLOAT DATE HAVING CLUSTER NOWAIT RESOURCE " + "ANY TABLE INDEX FOR UPDATE WHERE CHECK SMALLINT WITH DELETE " + "BY ASC REVOKE LIKE SIZE RENAME NOCOMPRESS NULL GROUP VALUES " + "AS IN VIEW EXCLUSIVE COMPRESS SYNONYM SELECT INSERT EXISTS " + "NOT TRIGGER ELSE CREATE INTERSECT PCTFREE DISTINCT USER " + "CONNECT SET MODE OF UNIQUE VARCHAR2 VARCHAR LOCK OR CHAR " + "DECIMAL UNION PUBLIC AND START UID COMMENT CURRENT LEVEL".split() +) -NO_ARG_FNS = set('UID CURRENT_DATE SYSDATE USER ' - 'CURRENT_TIME CURRENT_TIMESTAMP'.split()) +NO_ARG_FNS = set( + "UID CURRENT_DATE SYSDATE USER " "CURRENT_TIME CURRENT_TIMESTAMP".split() +) class RAW(sqltypes._Binary): - __visit_name__ = 'RAW' + __visit_name__ = "RAW" + + OracleRaw = RAW class NCLOB(sqltypes.Text): - __visit_name__ = 'NCLOB' + __visit_name__ = "NCLOB" class VARCHAR2(VARCHAR): - __visit_name__ = 'VARCHAR2' + __visit_name__ = "VARCHAR2" + NVARCHAR2 = NVARCHAR class NUMBER(sqltypes.Numeric, sqltypes.Integer): - __visit_name__ = 'NUMBER' + __visit_name__ = "NUMBER" def __init__(self, precision=None, scale=None, asdecimal=None): if asdecimal is None: asdecimal = bool(scale and scale > 0) super(NUMBER, self).__init__( - precision=precision, scale=scale, asdecimal=asdecimal) + precision=precision, scale=scale, asdecimal=asdecimal + ) def adapt(self, impltype): ret = super(NUMBER, self).adapt(impltype) @@ -412,23 +426,23 @@ class NUMBER(sqltypes.Numeric, sqltypes.Integer): class DOUBLE_PRECISION(sqltypes.Float): - __visit_name__ = 'DOUBLE_PRECISION' + __visit_name__ = "DOUBLE_PRECISION" class BINARY_DOUBLE(sqltypes.Float): - __visit_name__ = 'BINARY_DOUBLE' + __visit_name__ = "BINARY_DOUBLE" class BINARY_FLOAT(sqltypes.Float): - __visit_name__ = 'BINARY_FLOAT' + __visit_name__ = "BINARY_FLOAT" class BFILE(sqltypes.LargeBinary): - __visit_name__ = 'BFILE' + __visit_name__ = "BFILE" class LONG(sqltypes.Text): - __visit_name__ = 'LONG' + __visit_name__ = "LONG" class DATE(sqltypes.DateTime): @@ -441,18 +455,17 @@ class DATE(sqltypes.DateTime): .. versionadded:: 0.9.4 """ - __visit_name__ = 'DATE' + + __visit_name__ = "DATE" def _compare_type_affinity(self, other): return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) class INTERVAL(sqltypes.TypeEngine): - __visit_name__ = 'INTERVAL' + __visit_name__ = "INTERVAL" - def __init__(self, - day_precision=None, - second_precision=None): + def __init__(self, day_precision=None, second_precision=None): """Construct an INTERVAL. Note that only DAY TO SECOND intervals are currently supported. @@ -471,8 +484,10 @@ class INTERVAL(sqltypes.TypeEngine): @classmethod def _adapt_from_generic_interval(cls, interval): - return INTERVAL(day_precision=interval.day_precision, - second_precision=interval.second_precision) + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) @property def _type_affinity(self): @@ -485,38 +500,40 @@ class ROWID(sqltypes.TypeEngine): When used in a cast() or similar, generates ROWID. """ - __visit_name__ = 'ROWID' + + __visit_name__ = "ROWID" class _OracleBoolean(sqltypes.Boolean): def get_dbapi_type(self, dbapi): return dbapi.NUMBER + colspecs = { sqltypes.Boolean: _OracleBoolean, sqltypes.Interval: INTERVAL, - sqltypes.DateTime: DATE + sqltypes.DateTime: DATE, } ischema_names = { - 'VARCHAR2': VARCHAR, - 'NVARCHAR2': NVARCHAR, - 'CHAR': CHAR, - 'DATE': DATE, - 'NUMBER': NUMBER, - 'BLOB': BLOB, - 'BFILE': BFILE, - 'CLOB': CLOB, - 'NCLOB': NCLOB, - 'TIMESTAMP': TIMESTAMP, - 'TIMESTAMP WITH TIME ZONE': TIMESTAMP, - 'INTERVAL DAY TO SECOND': INTERVAL, - 'RAW': RAW, - 'FLOAT': FLOAT, - 'DOUBLE PRECISION': DOUBLE_PRECISION, - 'LONG': LONG, - 'BINARY_DOUBLE': BINARY_DOUBLE, - 'BINARY_FLOAT': BINARY_FLOAT + "VARCHAR2": VARCHAR, + "NVARCHAR2": NVARCHAR, + "CHAR": CHAR, + "DATE": DATE, + "NUMBER": NUMBER, + "BLOB": BLOB, + "BFILE": BFILE, + "CLOB": CLOB, + "NCLOB": NCLOB, + "TIMESTAMP": TIMESTAMP, + "TIMESTAMP WITH TIME ZONE": TIMESTAMP, + "INTERVAL DAY TO SECOND": INTERVAL, + "RAW": RAW, + "FLOAT": FLOAT, + "DOUBLE PRECISION": DOUBLE_PRECISION, + "LONG": LONG, + "BINARY_DOUBLE": BINARY_DOUBLE, + "BINARY_FLOAT": BINARY_FLOAT, } @@ -540,12 +557,12 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_INTERVAL(self, type_, **kw): return "INTERVAL DAY%s TO SECOND%s" % ( - type_.day_precision is not None and - "(%d)" % type_.day_precision or - "", - type_.second_precision is not None and - "(%d)" % type_.second_precision or - "", + type_.day_precision is not None + and "(%d)" % type_.day_precision + or "", + type_.second_precision is not None + and "(%d)" % type_.second_precision + or "", ) def visit_LONG(self, type_, **kw): @@ -569,52 +586,53 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_FLOAT(self, type_, **kw): # don't support conversion between decimal/binary # precision yet - kw['no_precision'] = True + kw["no_precision"] = True return self._generate_numeric(type_, "FLOAT", **kw) def visit_NUMBER(self, type_, **kw): return self._generate_numeric(type_, "NUMBER", **kw) def _generate_numeric( - self, type_, name, precision=None, - scale=None, no_precision=False, **kw): + self, type_, name, precision=None, scale=None, no_precision=False, **kw + ): if precision is None: precision = type_.precision if scale is None: - scale = getattr(type_, 'scale', None) + scale = getattr(type_, "scale", None) if no_precision or precision is None: return name elif scale is None: n = "%(name)s(%(precision)s)" - return n % {'name': name, 'precision': precision} + return n % {"name": name, "precision": precision} else: n = "%(name)s(%(precision)s, %(scale)s)" - return n % {'name': name, 'precision': precision, 'scale': scale} + return n % {"name": name, "precision": precision, "scale": scale} def visit_string(self, type_, **kw): return self.visit_VARCHAR2(type_, **kw) def visit_VARCHAR2(self, type_, **kw): - return self._visit_varchar(type_, '', '2') + return self._visit_varchar(type_, "", "2") def visit_NVARCHAR2(self, type_, **kw): - return self._visit_varchar(type_, 'N', '2') + return self._visit_varchar(type_, "N", "2") + visit_NVARCHAR = visit_NVARCHAR2 def visit_VARCHAR(self, type_, **kw): - return self._visit_varchar(type_, '', '') + return self._visit_varchar(type_, "", "") def _visit_varchar(self, type_, n, num): if not type_.length: - return "%(n)sVARCHAR%(two)s" % {'two': num, 'n': n} + return "%(n)sVARCHAR%(two)s" % {"two": num, "n": n} elif not n and self.dialect._supports_char_length: varchar = "VARCHAR%(two)s(%(length)s CHAR)" - return varchar % {'length': type_.length, 'two': num} + return varchar % {"length": type_.length, "two": num} else: varchar = "%(n)sVARCHAR%(two)s(%(length)s)" - return varchar % {'length': type_.length, 'two': num, 'n': n} + return varchar % {"length": type_.length, "two": num, "n": n} def visit_text(self, type_, **kw): return self.visit_CLOB(type_, **kw) @@ -636,7 +654,7 @@ class OracleTypeCompiler(compiler.GenericTypeCompiler): def visit_RAW(self, type_, **kw): if type_.length: - return "RAW(%(length)s)" % {'length': type_.length} + return "RAW(%(length)s)" % {"length": type_.length} else: return "RAW" @@ -652,9 +670,7 @@ class OracleCompiler(compiler.SQLCompiler): compound_keywords = util.update_copy( compiler.SQLCompiler.compound_keywords, - { - expression.CompoundSelect.EXCEPT: 'MINUS' - } + {expression.CompoundSelect.EXCEPT: "MINUS"}, ) def __init__(self, *args, **kwargs): @@ -663,8 +679,10 @@ class OracleCompiler(compiler.SQLCompiler): super(OracleCompiler, self).__init__(*args, **kwargs) def visit_mod_binary(self, binary, operator, **kw): - return "mod(%s, %s)" % (self.process(binary.left, **kw), - self.process(binary.right, **kw)) + return "mod(%s, %s)" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" @@ -673,22 +691,22 @@ class OracleCompiler(compiler.SQLCompiler): return "LENGTH" + self.function_argspec(fn, **kw) def visit_match_op_binary(self, binary, operator, **kw): - return "CONTAINS (%s, %s)" % (self.process(binary.left), - self.process(binary.right)) + return "CONTAINS (%s, %s)" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def get_cte_preamble(self, recursive): return "WITH" def get_select_hint_text(self, byfroms): - return " ".join( - "/*+ %s */" % text for table, text in byfroms.items() - ) + return " ".join("/*+ %s */" % text for table, text in byfroms.items()) def function_argspec(self, fn, **kw): if len(fn.clauses) > 0 or fn.name.upper() not in NO_ARG_FNS: @@ -709,13 +727,16 @@ class OracleCompiler(compiler.SQLCompiler): if self.dialect.use_ansi: return compiler.SQLCompiler.visit_join(self, join, **kwargs) else: - kwargs['asfrom'] = True + kwargs["asfrom"] = True if isinstance(join.right, expression.FromGrouping): right = join.right.element else: right = join.right - return self.process(join.left, **kwargs) + \ - ", " + self.process(right, **kwargs) + return ( + self.process(join.left, **kwargs) + + ", " + + self.process(right, **kwargs) + ) def _get_nonansi_join_whereclause(self, froms): clauses = [] @@ -727,14 +748,20 @@ class OracleCompiler(compiler.SQLCompiler): # the join condition in the WHERE clause" - that is, # unconditionally regardless of operator or the other side def visit_binary(binary): - if isinstance(binary.left, expression.ColumnClause) \ - and join.right.is_derived_from(binary.left.table): + if isinstance( + binary.left, expression.ColumnClause + ) and join.right.is_derived_from(binary.left.table): binary.left = _OuterJoinColumn(binary.left) - elif isinstance(binary.right, expression.ColumnClause) \ - and join.right.is_derived_from(binary.right.table): + elif isinstance( + binary.right, expression.ColumnClause + ) and join.right.is_derived_from(binary.right.table): binary.right = _OuterJoinColumn(binary.right) - clauses.append(visitors.cloned_traverse( - join.onclause, {}, {'binary': visit_binary})) + + clauses.append( + visitors.cloned_traverse( + join.onclause, {}, {"binary": visit_binary} + ) + ) else: clauses.append(join.onclause) @@ -757,8 +784,9 @@ class OracleCompiler(compiler.SQLCompiler): return self.process(vc.column, **kw) + "(+)" def visit_sequence(self, seq, **kw): - return (self.dialect.identifier_preparer.format_sequence(seq) + - ".nextval") + return ( + self.dialect.identifier_preparer.format_sequence(seq) + ".nextval" + ) def get_render_as_alias_suffix(self, alias_name_text): """Oracle doesn't like ``FROM table AS alias``""" @@ -770,7 +798,8 @@ class OracleCompiler(compiler.SQLCompiler): binds = [] for i, column in enumerate( - expression._select_iterables(returning_cols)): + expression._select_iterables(returning_cols) + ): if column.type._has_column_expression: col_expr = column.type.column_expression(column) else: @@ -779,19 +808,22 @@ class OracleCompiler(compiler.SQLCompiler): outparam = sql.outparam("ret_%d" % i, type_=column.type) self.binds[outparam.key] = outparam binds.append( - self.bindparam_string(self._truncate_bindparam(outparam))) - columns.append( - self.process(col_expr, within_columns_clause=False)) + self.bindparam_string(self._truncate_bindparam(outparam)) + ) + columns.append(self.process(col_expr, within_columns_clause=False)) self._add_to_result_map( - getattr(col_expr, 'name', col_expr.anon_label), - getattr(col_expr, 'name', col_expr.anon_label), - (column, getattr(column, 'name', None), - getattr(column, 'key', None)), - column.type + getattr(col_expr, "name", col_expr.anon_label), + getattr(col_expr, "name", col_expr.anon_label), + ( + column, + getattr(column, "name", None), + getattr(column, "key", None), + ), + column.type, ) - return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) def _TODO_visit_compound_select(self, select): """Need to determine how to get ``LIMIT``/``OFFSET`` into a @@ -804,10 +836,11 @@ class OracleCompiler(compiler.SQLCompiler): so tries to wrap it in a subquery with ``rownum`` criterion. """ - if not getattr(select, '_oracle_visit', None): + if not getattr(select, "_oracle_visit", None): if not self.dialect.use_ansi: froms = self._display_froms_for_select( - select, kwargs.get('asfrom', False)) + select, kwargs.get("asfrom", False) + ) whereclause = self._get_nonansi_join_whereclause(froms) if whereclause is not None: select = select.where(whereclause) @@ -828,18 +861,20 @@ class OracleCompiler(compiler.SQLCompiler): # Outer select and "ROWNUM as ora_rn" can be dropped if # limit=0 - kwargs['select_wraps_for'] = select + kwargs["select_wraps_for"] = select select = select._generate() select._oracle_visit = True # Wrap the middle select and add the hint limitselect = sql.select([c for c in select.c]) - if limit_clause is not None and \ - self.dialect.optimize_limits and \ - select._simple_int_limit: + if ( + limit_clause is not None + and self.dialect.optimize_limits + and select._simple_int_limit + ): limitselect = limitselect.prefix_with( - "/*+ FIRST_ROWS(%d) */" % - select._limit) + "/*+ FIRST_ROWS(%d) */" % select._limit + ) limitselect._oracle_visit = True limitselect._is_wrapper = True @@ -855,8 +890,8 @@ class OracleCompiler(compiler.SQLCompiler): adapter = sql_util.ClauseAdapter(select) for_update.of = [ - adapter.traverse(elem) - for elem in for_update.of] + adapter.traverse(elem) for elem in for_update.of + ] # If needed, add the limiting clause if limit_clause is not None: @@ -873,7 +908,8 @@ class OracleCompiler(compiler.SQLCompiler): if offset_clause is not None: max_row = max_row + offset_clause limitselect.append_whereclause( - sql.literal_column("ROWNUM") <= max_row) + sql.literal_column("ROWNUM") <= max_row + ) # If needed, add the ora_rn, and wrap again with offset. if offset_clause is None: @@ -881,12 +917,14 @@ class OracleCompiler(compiler.SQLCompiler): select = limitselect else: limitselect = limitselect.column( - sql.literal_column("ROWNUM").label("ora_rn")) + sql.literal_column("ROWNUM").label("ora_rn") + ) limitselect._oracle_visit = True limitselect._is_wrapper = True offsetselect = sql.select( - [c for c in limitselect.c if c.key != 'ora_rn']) + [c for c in limitselect.c if c.key != "ora_rn"] + ) offsetselect._oracle_visit = True offsetselect._is_wrapper = True @@ -897,9 +935,11 @@ class OracleCompiler(compiler.SQLCompiler): if not self.dialect.use_binds_for_limits: offset_clause = sql.literal_column( - "%d" % select._offset) + "%d" % select._offset + ) offsetselect.append_whereclause( - sql.literal_column("ora_rn") > offset_clause) + sql.literal_column("ora_rn") > offset_clause + ) offsetselect._for_update_arg = for_update select = offsetselect @@ -910,18 +950,17 @@ class OracleCompiler(compiler.SQLCompiler): return "" def visit_empty_set_expr(self, type_): - return 'SELECT 1 FROM DUAL WHERE 1!=1' + return "SELECT 1 FROM DUAL WHERE 1!=1" def for_update_clause(self, select, **kw): if self.is_subquery(): return "" - tmp = ' FOR UPDATE' + tmp = " FOR UPDATE" if select._for_update_arg.of: - tmp += ' OF ' + ', '.join( - self.process(elem, **kw) for elem in - select._for_update_arg.of + tmp += " OF " + ", ".join( + self.process(elem, **kw) for elem in select._for_update_arg.of ) if select._for_update_arg.nowait: @@ -933,7 +972,6 @@ class OracleCompiler(compiler.SQLCompiler): class OracleDDLCompiler(compiler.DDLCompiler): - def define_constraint_cascades(self, constraint): text = "" if constraint.ondelete is not None: @@ -947,7 +985,8 @@ class OracleDDLCompiler(compiler.DDLCompiler): "Oracle does not contain native UPDATE CASCADE " "functionality - onupdates will not be rendered for foreign " "keys. Consider using deferrable=True, initially='deferred' " - "or triggers.") + "or triggers." + ) return text @@ -958,75 +997,79 @@ class OracleDDLCompiler(compiler.DDLCompiler): text = "CREATE " if index.unique: text += "UNIQUE " - if index.dialect_options['oracle']['bitmap']: + if index.dialect_options["oracle"]["bitmap"]: text += "BITMAP " text += "INDEX %s ON %s (%s)" % ( self._prepared_index_name(index, include_schema=True), preparer.format_table(index.table, use_schema=True), - ', '.join( + ", ".join( self.sql_compiler.process( - expr, - include_table=False, literal_binds=True) - for expr in index.expressions) + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), ) - if index.dialect_options['oracle']['compress'] is not False: - if index.dialect_options['oracle']['compress'] is True: + if index.dialect_options["oracle"]["compress"] is not False: + if index.dialect_options["oracle"]["compress"] is True: text += " COMPRESS" else: text += " COMPRESS %d" % ( - index.dialect_options['oracle']['compress'] + index.dialect_options["oracle"]["compress"] ) return text def post_create_table(self, table): table_opts = [] - opts = table.dialect_options['oracle'] + opts = table.dialect_options["oracle"] - if opts['on_commit']: - on_commit_options = opts['on_commit'].replace("_", " ").upper() - table_opts.append('\n ON COMMIT %s' % on_commit_options) + if opts["on_commit"]: + on_commit_options = opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) - if opts['compress']: - if opts['compress'] is True: + if opts["compress"]: + if opts["compress"] is True: table_opts.append("\n COMPRESS") else: - table_opts.append("\n COMPRESS FOR %s" % ( - opts['compress'] - )) + table_opts.append("\n COMPRESS FOR %s" % (opts["compress"])) - return ''.join(table_opts) + return "".join(table_opts) class OracleIdentifierPreparer(compiler.IdentifierPreparer): reserved_words = {x.lower() for x in RESERVED_WORDS} - illegal_initial_characters = {str(dig) for dig in range(0, 10)} \ - .union(["_", "$"]) + illegal_initial_characters = {str(dig) for dig in range(0, 10)}.union( + ["_", "$"] + ) def _bindparam_requires_quotes(self, value): """Return True if the given identifier requires quoting.""" lc_value = value.lower() - return (lc_value in self.reserved_words - or value[0] in self.illegal_initial_characters - or not self.legal_characters.match(util.text_type(value)) - ) + return ( + lc_value in self.reserved_words + or value[0] in self.illegal_initial_characters + or not self.legal_characters.match(util.text_type(value)) + ) def format_savepoint(self, savepoint): - name = savepoint.ident.lstrip('_') - return super( - OracleIdentifierPreparer, self).format_savepoint(savepoint, name) + name = savepoint.ident.lstrip("_") + return super(OracleIdentifierPreparer, self).format_savepoint( + savepoint, name + ) class OracleExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( - "SELECT " + - self.dialect.identifier_preparer.format_sequence(seq) + - ".nextval FROM DUAL", type_) + "SELECT " + + self.dialect.identifier_preparer.format_sequence(seq) + + ".nextval FROM DUAL", + type_, + ) class OracleDialect(default.DefaultDialect): - name = 'oracle' + name = "oracle" supports_alter = True supports_unicode_statements = False supports_unicode_binds = False @@ -1039,7 +1082,7 @@ class OracleDialect(default.DefaultDialect): sequences_optional = False postfetch_lastrowid = False - default_paramstyle = 'named' + default_paramstyle = "named" colspecs = colspecs ischema_names = ischema_names requires_name_normalize = True @@ -1054,29 +1097,27 @@ class OracleDialect(default.DefaultDialect): preparer = OracleIdentifierPreparer execution_ctx_cls = OracleExecutionContext - reflection_options = ('oracle_resolve_synonyms', ) + reflection_options = ("oracle_resolve_synonyms",) _use_nchar_for_unicode = False construct_arguments = [ - (sa_schema.Table, { - "resolve_synonyms": False, - "on_commit": None, - "compress": False - }), - (sa_schema.Index, { - "bitmap": False, - "compress": False - }) + ( + sa_schema.Table, + {"resolve_synonyms": False, "on_commit": None, "compress": False}, + ), + (sa_schema.Index, {"bitmap": False, "compress": False}), ] - def __init__(self, - use_ansi=True, - optimize_limits=False, - use_binds_for_limits=True, - use_nchar_for_unicode=False, - exclude_tablespaces=('SYSTEM', 'SYSAUX', ), - **kwargs): + def __init__( + self, + use_ansi=True, + optimize_limits=False, + use_binds_for_limits=True, + use_nchar_for_unicode=False, + exclude_tablespaces=("SYSTEM", "SYSAUX"), + **kwargs + ): default.DefaultDialect.__init__(self, **kwargs) self._use_nchar_for_unicode = use_nchar_for_unicode self.use_ansi = use_ansi @@ -1087,8 +1128,7 @@ class OracleDialect(default.DefaultDialect): def initialize(self, connection): super(OracleDialect, self).initialize(connection) self.implicit_returning = self.__dict__.get( - 'implicit_returning', - self.server_version_info > (10, ) + "implicit_returning", self.server_version_info > (10,) ) if self._is_oracle_8: @@ -1098,18 +1138,15 @@ class OracleDialect(default.DefaultDialect): @property def _is_oracle_8(self): - return self.server_version_info and \ - self.server_version_info < (9, ) + return self.server_version_info and self.server_version_info < (9,) @property def _supports_table_compression(self): - return self.server_version_info and \ - self.server_version_info >= (10, 1, ) + return self.server_version_info and self.server_version_info >= (10, 1) @property def _supports_table_compress_for(self): - return self.server_version_info and \ - self.server_version_info >= (11, ) + return self.server_version_info and self.server_version_info >= (11,) @property def _supports_char_length(self): @@ -1123,31 +1160,38 @@ class OracleDialect(default.DefaultDialect): additional_tests = [ expression.cast( expression.literal_column("'test nvarchar2 returns'"), - sqltypes.NVARCHAR(60) - ), + sqltypes.NVARCHAR(60), + ) ] return super(OracleDialect, self)._check_unicode_returns( - connection, additional_tests) + connection, additional_tests + ) def has_table(self, connection, table_name, schema=None): if not schema: schema = self.default_schema_name cursor = connection.execute( - sql.text("SELECT table_name FROM all_tables " - "WHERE table_name = :name AND owner = :schema_name"), + sql.text( + "SELECT table_name FROM all_tables " + "WHERE table_name = :name AND owner = :schema_name" + ), name=self.denormalize_name(table_name), - schema_name=self.denormalize_name(schema)) + schema_name=self.denormalize_name(schema), + ) return cursor.first() is not None def has_sequence(self, connection, sequence_name, schema=None): if not schema: schema = self.default_schema_name cursor = connection.execute( - sql.text("SELECT sequence_name FROM all_sequences " - "WHERE sequence_name = :name AND " - "sequence_owner = :schema_name"), + sql.text( + "SELECT sequence_name FROM all_sequences " + "WHERE sequence_name = :name AND " + "sequence_owner = :schema_name" + ), name=self.denormalize_name(sequence_name), - schema_name=self.denormalize_name(schema)) + schema_name=self.denormalize_name(schema), + ) return cursor.first() is not None def normalize_name(self, name): @@ -1156,8 +1200,9 @@ class OracleDialect(default.DefaultDialect): if util.py2k: if isinstance(name, str): name = name.decode(self.encoding) - if name.upper() == name and not \ - self.identifier_preparer._requires_quotes(name.lower()): + if name.upper() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): return name.lower() elif name.lower() == name: return quoted_name(name, quote=True) @@ -1167,8 +1212,9 @@ class OracleDialect(default.DefaultDialect): def denormalize_name(self, name): if name is None: return None - elif name.lower() == name and not \ - self.identifier_preparer._requires_quotes(name.lower()): + elif name.lower() == name and not self.identifier_preparer._requires_quotes( + name.lower() + ): name = name.upper() if util.py2k: if not self.supports_unicode_binds: @@ -1179,10 +1225,16 @@ class OracleDialect(default.DefaultDialect): def _get_default_schema_name(self, connection): return self.normalize_name( - connection.execute('SELECT USER FROM DUAL').scalar()) + connection.execute("SELECT USER FROM DUAL").scalar() + ) - def _resolve_synonym(self, connection, desired_owner=None, - desired_synonym=None, desired_table=None): + def _resolve_synonym( + self, + connection, + desired_owner=None, + desired_synonym=None, + desired_table=None, + ): """search for a local synonym matching the given desired owner/name. if desired_owner is None, attempts to locate a distinct owner. @@ -1191,19 +1243,21 @@ class OracleDialect(default.DefaultDialect): found. """ - q = "SELECT owner, table_owner, table_name, db_link, "\ + q = ( + "SELECT owner, table_owner, table_name, db_link, " "synonym_name FROM all_synonyms WHERE " + ) clauses = [] params = {} if desired_synonym: clauses.append("synonym_name = :synonym_name") - params['synonym_name'] = desired_synonym + params["synonym_name"] = desired_synonym if desired_owner: clauses.append("owner = :desired_owner") - params['desired_owner'] = desired_owner + params["desired_owner"] = desired_owner if desired_table: clauses.append("table_name = :tname") - params['tname'] = desired_table + params["tname"] = desired_table q += " AND ".join(clauses) @@ -1211,8 +1265,12 @@ class OracleDialect(default.DefaultDialect): if desired_owner: row = result.first() if row: - return (row['table_name'], row['table_owner'], - row['db_link'], row['synonym_name']) + return ( + row["table_name"], + row["table_owner"], + row["db_link"], + row["synonym_name"], + ) else: return None, None, None, None else: @@ -1220,23 +1278,35 @@ class OracleDialect(default.DefaultDialect): if len(rows) > 1: raise AssertionError( "There are multiple tables visible to the schema, you " - "must specify owner") + "must specify owner" + ) elif len(rows) == 1: row = rows[0] - return (row['table_name'], row['table_owner'], - row['db_link'], row['synonym_name']) + return ( + row["table_name"], + row["table_owner"], + row["db_link"], + row["synonym_name"], + ) else: return None, None, None, None @reflection.cache - def _prepare_reflection_args(self, connection, table_name, schema=None, - resolve_synonyms=False, dblink='', **kw): + def _prepare_reflection_args( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): if resolve_synonyms: actual_name, owner, dblink, synonym = self._resolve_synonym( connection, desired_owner=self.denormalize_name(schema), - desired_synonym=self.denormalize_name(table_name) + desired_synonym=self.denormalize_name(table_name), ) else: actual_name, owner, dblink, synonym = None, None, None, None @@ -1250,18 +1320,21 @@ class OracleDialect(default.DefaultDialect): # will need to hear from more users if we are doing # the right thing here. See [ticket:2619] owner = connection.scalar( - sql.text("SELECT username FROM user_db_links " - "WHERE db_link=:link"), link=dblink) + sql.text( + "SELECT username FROM user_db_links " "WHERE db_link=:link" + ), + link=dblink, + ) dblink = "@" + dblink elif not owner: owner = self.denormalize_name(schema or self.default_schema_name) - return (actual_name, owner, dblink or '', synonym) + return (actual_name, owner, dblink or "", synonym) @reflection.cache def get_schema_names(self, connection, **kw): s = "SELECT username FROM all_users ORDER BY username" - cursor = connection.execute(s,) + cursor = connection.execute(s) return [self.normalize_name(row[0]) for row in cursor] @reflection.cache @@ -1276,14 +1349,12 @@ class OracleDialect(default.DefaultDialect): if self.exclude_tablespaces: sql_str += ( "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " % ( - ', '.join(["'%s'" % ts for ts in self.exclude_tablespaces]) - ) + "NOT IN (%s) AND " + % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) ) sql_str += ( - "OWNER = :owner " - "AND IOT_NAME IS NULL " - "AND DURATION IS NULL") + "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL" + ) cursor = connection.execute(sql.text(sql_str), owner=schema) return [self.normalize_name(row[0]) for row in cursor] @@ -1296,14 +1367,14 @@ class OracleDialect(default.DefaultDialect): if self.exclude_tablespaces: sql_str += ( "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " % ( - ', '.join(["'%s'" % ts for ts in self.exclude_tablespaces]) - ) + "NOT IN (%s) AND " + % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) ) sql_str += ( "OWNER = :owner " "AND IOT_NAME IS NULL " - "AND DURATION IS NOT NULL") + "AND DURATION IS NOT NULL" + ) cursor = connection.execute(sql.text(sql_str), owner=schema) return [self.normalize_name(row[0]) for row in cursor] @@ -1319,14 +1390,18 @@ class OracleDialect(default.DefaultDialect): def get_table_options(self, connection, table_name, schema=None, **kw): options = {} - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) params = {"table_name": table_name} @@ -1336,14 +1411,16 @@ class OracleDialect(default.DefaultDialect): if self._supports_table_compress_for: columns.append("compress_for") - text = "SELECT %(columns)s "\ - "FROM ALL_TABLES%(dblink)s "\ + text = ( + "SELECT %(columns)s " + "FROM ALL_TABLES%(dblink)s " "WHERE table_name = :table_name" + ) if schema is not None: - params['owner'] = schema + params["owner"] = schema text += " AND owner = :owner " - text = text % {'dblink': dblink, 'columns': ", ".join(columns)} + text = text % {"dblink": dblink, "columns": ", ".join(columns)} result = connection.execute(sql.text(text), **params) @@ -1353,9 +1430,9 @@ class OracleDialect(default.DefaultDialect): if row: if "compression" in row and enabled.get(row.compression, False): if "compress_for" in row: - options['oracle_compress'] = row.compress_for + options["oracle_compress"] = row.compress_for else: - options['oracle_compress'] = True + options["oracle_compress"] = True return options @@ -1371,19 +1448,23 @@ class OracleDialect(default.DefaultDialect): """ - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) columns = [] if self._supports_char_length: - char_length_col = 'char_length' + char_length_col = "char_length" else: - char_length_col = 'data_length' + char_length_col = "data_length" params = {"table_name": table_name} text = """ @@ -1398,10 +1479,10 @@ class OracleDialect(default.DefaultDialect): WHERE col.table_name = :table_name """ if schema is not None: - params['owner'] = schema + params["owner"] = schema text += " AND col.owner = :owner " text += " ORDER BY col.column_id" - text = text % {'dblink': dblink, 'char_length_col': char_length_col} + text = text % {"dblink": dblink, "char_length_col": char_length_col} c = connection.execute(sql.text(text), **params) @@ -1412,54 +1493,67 @@ class OracleDialect(default.DefaultDialect): length = row[2] precision = row[3] scale = row[4] - nullable = row[5] == 'Y' + nullable = row[5] == "Y" default = row[6] comment = row[7] - if coltype == 'NUMBER': + if coltype == "NUMBER": if precision is None and scale == 0: coltype = INTEGER() else: coltype = NUMBER(precision, scale) - elif coltype == 'FLOAT': + elif coltype == "FLOAT": # TODO: support "precision" here as "binary_precision" coltype = FLOAT() - elif coltype in ('VARCHAR2', 'NVARCHAR2', 'CHAR'): + elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR"): coltype = self.ischema_names.get(coltype)(length) - elif 'WITH TIME ZONE' in coltype: + elif "WITH TIME ZONE" in coltype: coltype = TIMESTAMP(timezone=True) else: - coltype = re.sub(r'\(\d+\)', '', coltype) + coltype = re.sub(r"\(\d+\)", "", coltype) try: coltype = self.ischema_names[coltype] except KeyError: - util.warn("Did not recognize type '%s' of column '%s'" % - (coltype, colname)) + util.warn( + "Did not recognize type '%s' of column '%s'" + % (coltype, colname) + ) coltype = sqltypes.NULLTYPE cdict = { - 'name': colname, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': 'auto', - 'comment': comment, + "name": colname, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": "auto", + "comment": comment, } if orig_colname.lower() == orig_colname: - cdict['quote'] = True + cdict["quote"] = True columns.append(cdict) return columns @reflection.cache - def get_table_comment(self, connection, table_name, schema=None, - resolve_synonyms=False, dblink='', **kw): - - info_cache = kw.get('info_cache') - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_table_comment( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + + info_cache = kw.get("info_cache") + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) COMMENT_SQL = """ SELECT comments @@ -1471,67 +1565,90 @@ class OracleDialect(default.DefaultDialect): return {"text": c.scalar()} @reflection.cache - def get_indexes(self, connection, table_name, schema=None, - resolve_synonyms=False, dblink='', **kw): - - info_cache = kw.get('info_cache') - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_indexes( + self, + connection, + table_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + + info_cache = kw.get("info_cache") + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) indexes = [] - params = {'table_name': table_name} - text = \ - "SELECT a.index_name, a.column_name, "\ - "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "\ - "\nFROM ALL_IND_COLUMNS%(dblink)s a, "\ - "\nALL_INDEXES%(dblink)s b "\ - "\nWHERE "\ - "\na.index_name = b.index_name "\ - "\nAND a.table_owner = b.table_owner "\ - "\nAND a.table_name = b.table_name "\ + params = {"table_name": table_name} + text = ( + "SELECT a.index_name, a.column_name, " + "\nb.index_type, b.uniqueness, b.compression, b.prefix_length " + "\nFROM ALL_IND_COLUMNS%(dblink)s a, " + "\nALL_INDEXES%(dblink)s b " + "\nWHERE " + "\na.index_name = b.index_name " + "\nAND a.table_owner = b.table_owner " + "\nAND a.table_name = b.table_name " "\nAND a.table_name = :table_name " + ) if schema is not None: - params['schema'] = schema + params["schema"] = schema text += "AND a.table_owner = :schema " text += "ORDER BY a.index_name, a.column_position" - text = text % {'dblink': dblink} + text = text % {"dblink": dblink} q = sql.text(text) rp = connection.execute(q, **params) indexes = [] last_index_name = None pk_constraint = self.get_pk_constraint( - connection, table_name, schema, resolve_synonyms=resolve_synonyms, - dblink=dblink, info_cache=kw.get('info_cache')) - pkeys = pk_constraint['constrained_columns'] + connection, + table_name, + schema, + resolve_synonyms=resolve_synonyms, + dblink=dblink, + info_cache=kw.get("info_cache"), + ) + pkeys = pk_constraint["constrained_columns"] uniqueness = dict(NONUNIQUE=False, UNIQUE=True) enabled = dict(DISABLED=False, ENABLED=True) - oracle_sys_col = re.compile(r'SYS_NC\d+\$', re.IGNORECASE) + oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) index = None for rset in rp: if rset.index_name != last_index_name: - index = dict(name=self.normalize_name(rset.index_name), - column_names=[], dialect_options={}) + index = dict( + name=self.normalize_name(rset.index_name), + column_names=[], + dialect_options={}, + ) indexes.append(index) - index['unique'] = uniqueness.get(rset.uniqueness, False) + index["unique"] = uniqueness.get(rset.uniqueness, False) - if rset.index_type in ('BITMAP', 'FUNCTION-BASED BITMAP'): - index['dialect_options']['oracle_bitmap'] = True + if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"): + index["dialect_options"]["oracle_bitmap"] = True if enabled.get(rset.compression, False): - index['dialect_options']['oracle_compress'] = rset.prefix_length + index["dialect_options"][ + "oracle_compress" + ] = rset.prefix_length # filter out Oracle SYS_NC names. could also do an outer join # to the all_tab_columns table and check for real col names there. if not oracle_sys_col.match(rset.column_name): - index['column_names'].append( - self.normalize_name(rset.column_name)) + index["column_names"].append( + self.normalize_name(rset.column_name) + ) last_index_name = rset.index_name def upper_name_set(names): @@ -1539,18 +1656,21 @@ class OracleDialect(default.DefaultDialect): pk_names = upper_name_set(pkeys) if pk_names: + def is_pk_index(index): # don't include the primary key index - return upper_name_set(index['column_names']) == pk_names + return upper_name_set(index["column_names"]) == pk_names + indexes = [idx for idx in indexes if not is_pk_index(idx)] return indexes @reflection.cache - def _get_constraint_data(self, connection, table_name, schema=None, - dblink='', **kw): + def _get_constraint_data( + self, connection, table_name, schema=None, dblink="", **kw + ): - params = {'table_name': table_name} + params = {"table_name": table_name} text = ( "SELECT" @@ -1572,7 +1692,7 @@ class OracleDialect(default.DefaultDialect): ) if schema is not None: - params['owner'] = schema + params["owner"] = schema text += "\nAND ac.owner = :owner" text += ( @@ -1584,35 +1704,49 @@ class OracleDialect(default.DefaultDialect): "\nORDER BY ac.constraint_name, loc.position" ) - text = text % {'dblink': dblink} + text = text % {"dblink": dblink} rp = connection.execute(sql.text(text), **params) constraint_data = rp.fetchall() return constraint_data @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) pkeys = [] constraint_name = None constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) for row in constraint_data: - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) - if cons_type == 'P': + ( + cons_name, + cons_type, + local_column, + remote_table, + remote_column, + remote_owner, + ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + if cons_type == "P": if constraint_name is None: constraint_name = self.normalize_name(cons_name) pkeys.append(local_column) - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, "name": constraint_name} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -1626,74 +1760,94 @@ class OracleDialect(default.DefaultDialect): """ requested_schema = schema # to check later on - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) def fkey_rec(): return { - 'name': None, - 'constrained_columns': [], - 'referred_schema': None, - 'referred_table': None, - 'referred_columns': [], - 'options': {}, + "name": None, + "constrained_columns": [], + "referred_schema": None, + "referred_table": None, + "referred_columns": [], + "options": {}, } fkeys = util.defaultdict(fkey_rec) for row in constraint_data: - (cons_name, cons_type, local_column, remote_table, remote_column, remote_owner) = \ - row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) + ( + cons_name, + cons_type, + local_column, + remote_table, + remote_column, + remote_owner, + ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) cons_name = self.normalize_name(cons_name) - if cons_type == 'R': + if cons_type == "R": if remote_table is None: # ticket 363 util.warn( - ("Got 'None' querying 'table_name' from " - "all_cons_columns%(dblink)s - does the user have " - "proper rights to the table?") % {'dblink': dblink}) + ( + "Got 'None' querying 'table_name' from " + "all_cons_columns%(dblink)s - does the user have " + "proper rights to the table?" + ) + % {"dblink": dblink} + ) continue rec = fkeys[cons_name] - rec['name'] = cons_name - local_cols, remote_cols = rec[ - 'constrained_columns'], rec['referred_columns'] + rec["name"] = cons_name + local_cols, remote_cols = ( + rec["constrained_columns"], + rec["referred_columns"], + ) - if not rec['referred_table']: + if not rec["referred_table"]: if resolve_synonyms: - ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = \ - self._resolve_synonym( - connection, - desired_owner=self.denormalize_name( - remote_owner), - desired_table=self.denormalize_name( - remote_table) - ) + ref_remote_name, ref_remote_owner, ref_dblink, ref_synonym = self._resolve_synonym( + connection, + desired_owner=self.denormalize_name(remote_owner), + desired_table=self.denormalize_name(remote_table), + ) if ref_synonym: remote_table = self.normalize_name(ref_synonym) remote_owner = self.normalize_name( - ref_remote_owner) + ref_remote_owner + ) - rec['referred_table'] = remote_table + rec["referred_table"] = remote_table - if requested_schema is not None or \ - self.denormalize_name(remote_owner) != schema: - rec['referred_schema'] = remote_owner + if ( + requested_schema is not None + or self.denormalize_name(remote_owner) != schema + ): + rec["referred_schema"] = remote_owner - if row[9] != 'NO ACTION': - rec['options']['ondelete'] = row[9] + if row[9] != "NO ACTION": + rec["options"]["ondelete"] = row[9] local_cols.append(local_column) remote_cols.append(remote_column) @@ -1701,54 +1855,82 @@ class OracleDialect(default.DefaultDialect): return list(fkeys.values()) @reflection.cache - def get_unique_constraints(self, connection, table_name, schema=None, **kw): - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) - unique_keys = filter(lambda x: x[1] == 'U', constraint_data) + unique_keys = filter(lambda x: x[1] == "U", constraint_data) uniques_group = groupby(unique_keys, lambda x: x[0]) - index_names = set([ix['name'] for ix in self.get_indexes(connection, table_name, schema=schema)]) + index_names = set( + [ + ix["name"] + for ix in self.get_indexes( + connection, table_name, schema=schema + ) + ] + ) return [ { - 'name': name, - 'column_names': cols, - 'duplicates_index': name if name in index_names else None + "name": name, + "column_names": cols, + "duplicates_index": name if name in index_names else None, } - for name, cols in - [ + for name, cols in [ [ self.normalize_name(i[0]), - [self.normalize_name(x[2]) for x in i[1]] - ] for i in uniques_group + [self.normalize_name(x[2]) for x in i[1]], + ] + for i in uniques_group ] ] @reflection.cache - def get_view_definition(self, connection, view_name, schema=None, - resolve_synonyms=False, dblink='', **kw): - info_cache = kw.get('info_cache') - (view_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, view_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) - - params = {'view_name': view_name} + def get_view_definition( + self, + connection, + view_name, + schema=None, + resolve_synonyms=False, + dblink="", + **kw + ): + info_cache = kw.get("info_cache") + (view_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + view_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) + + params = {"view_name": view_name} text = "SELECT text FROM all_views WHERE view_name=:view_name" if schema is not None: text += " AND owner = :schema" - params['schema'] = schema + params["schema"] = schema rp = connection.execute(sql.text(text), **params).scalar() if rp: @@ -1759,34 +1941,41 @@ class OracleDialect(default.DefaultDialect): return None @reflection.cache - def get_check_constraints(self, connection, table_name, schema=None, - include_all=False, **kw): - resolve_synonyms = kw.get('oracle_resolve_synonyms', False) - dblink = kw.get('dblink', '') - info_cache = kw.get('info_cache') - - (table_name, schema, dblink, synonym) = \ - self._prepare_reflection_args(connection, table_name, schema, - resolve_synonyms, dblink, - info_cache=info_cache) + def get_check_constraints( + self, connection, table_name, schema=None, include_all=False, **kw + ): + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) + dblink = kw.get("dblink", "") + info_cache = kw.get("info_cache") + + (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + connection, + table_name, + schema, + resolve_synonyms, + dblink, + info_cache=info_cache, + ) constraint_data = self._get_constraint_data( - connection, table_name, schema, dblink, - info_cache=kw.get('info_cache')) + connection, + table_name, + schema, + dblink, + info_cache=kw.get("info_cache"), + ) - check_constraints = filter(lambda x: x[1] == 'C', constraint_data) + check_constraints = filter(lambda x: x[1] == "C", constraint_data) return [ - { - 'name': self.normalize_name(cons[0]), - 'sqltext': cons[8], - } - for cons in check_constraints if include_all or - not re.match(r'..+?. IS NOT NULL$', cons[8])] + {"name": self.normalize_name(cons[0]), "sqltext": cons[8]} + for cons in check_constraints + if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8]) + ] class _OuterJoinColumn(sql.ClauseElement): - __visit_name__ = 'outer_join_column' + __visit_name__ = "outer_join_column" def __init__(self, column): self.column = column diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index a00e7d95e..91534c0da 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -296,16 +296,13 @@ class _OracleInteger(sqltypes.Integer): def _cx_oracle_var(self, dialect, cursor): cx_Oracle = dialect.dbapi return cursor.var( - cx_Oracle.STRING, - 255, - arraysize=cursor.arraysize, - outconverter=int + cx_Oracle.STRING, 255, arraysize=cursor.arraysize, outconverter=int ) def _cx_oracle_outputtypehandler(self, dialect): - def handler(cursor, name, - default_type, size, precision, scale): + def handler(cursor, name, default_type, size, precision, scale): return self._cx_oracle_var(dialect, cursor) + return handler @@ -317,7 +314,8 @@ class _OracleNumeric(sqltypes.Numeric): return None elif self.asdecimal: processor = processors.to_decimal_processor_factory( - decimal.Decimal, self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) def process(value): if isinstance(value, (int, float)): @@ -326,6 +324,7 @@ class _OracleNumeric(sqltypes.Numeric): return float(value) else: return value + return process else: return processors.to_float @@ -383,9 +382,10 @@ class _OracleNumeric(sqltypes.Numeric): type_ = cx_Oracle.NATIVE_FLOAT return cursor.var( - type_, 255, + type_, + 255, arraysize=cursor.arraysize, - outconverter=outconverter + outconverter=outconverter, ) return handler @@ -418,6 +418,7 @@ class _OracleDate(sqltypes.Date): return value.date() else: return value + return process @@ -467,6 +468,7 @@ class _OracleEnum(sqltypes.Enum): def process(value): raw_str = enum_proc(value) return raw_str + return process @@ -482,7 +484,8 @@ class _OracleBinary(sqltypes.LargeBinary): return None else: return super(_OracleBinary, self).result_processor( - dialect, coltype) + dialect, coltype + ) class _OracleInterval(oracle.INTERVAL): @@ -503,14 +506,18 @@ class OracleCompiler_cx_oracle(OracleCompiler): _oracle_cx_sql_compiler = True def bindparam_string(self, name, **kw): - quote = getattr(name, 'quote', None) - if quote is True or quote is not False and \ - self.preparer._bindparam_requires_quotes(name): - if kw.get('expanding', False): + quote = getattr(name, "quote", None) + if ( + quote is True + or quote is not False + and self.preparer._bindparam_requires_quotes(name) + ): + if kw.get("expanding", False): raise exc.CompileError( "Can't use expanding feature with parameter name " "%r on Oracle; it requires quoting which is not supported " - "in this context." % name) + "in this context." % name + ) quoted_name = '"%s"' % name self._quoted_bind_names[name] = quoted_name return OracleCompiler.bindparam_string(self, quoted_name, **kw) @@ -537,21 +544,22 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): if bindparam.isoutparam: name = self.compiled.bind_names[bindparam] type_impl = bindparam.type.dialect_impl(self.dialect) - if hasattr(type_impl, '_cx_oracle_var'): + if hasattr(type_impl, "_cx_oracle_var"): self.out_parameters[name] = type_impl._cx_oracle_var( - self.dialect, self.cursor) + self.dialect, self.cursor + ) else: dbtype = type_impl.get_dbapi_type(self.dialect.dbapi) if dbtype is None: raise exc.InvalidRequestError( "Cannot create out parameter for parameter " "%r - its type %r is not supported by" - " cx_oracle" % - (bindparam.key, bindparam.type) + " cx_oracle" % (bindparam.key, bindparam.type) ) self.out_parameters[name] = self.cursor.var(dbtype) - self.parameters[0][quoted_bind_names.get(name, name)] = \ - self.out_parameters[name] + self.parameters[0][ + quoted_bind_names.get(name, name) + ] = self.out_parameters[name] def _generate_cursor_outputtype_handler(self): output_handlers = {} @@ -559,8 +567,9 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): for (keyname, name, objects, type_) in self.compiled._result_columns: handler = type_._cached_custom_processor( self.dialect, - 'cx_oracle_outputtypehandler', - self._get_cx_oracle_type_handler) + "cx_oracle_outputtypehandler", + self._get_cx_oracle_type_handler, + ) if handler: denormalized_name = self.dialect.denormalize_name(keyname) @@ -569,16 +578,18 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): if output_handlers: default_handler = self._dbapi_connection.outputtypehandler - def output_type_handler(cursor, name, default_type, - size, precision, scale): + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): if name in output_handlers: return output_handlers[name]( - cursor, name, - default_type, size, precision, scale) + cursor, name, default_type, size, precision, scale + ) else: return default_handler( cursor, name, default_type, size, precision, scale ) + self.cursor.outputtypehandler = output_type_handler def _get_cx_oracle_type_handler(self, impl): @@ -598,7 +609,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): self.set_input_sizes( self.compiled._quoted_bind_names, - include_types=self.dialect._include_setinputsizes + include_types=self.dialect._include_setinputsizes, ) self._handle_out_parameters() @@ -615,9 +626,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): def get_result_proxy(self): if self.out_parameters and self.compiled.returning: returning_params = [ - self.dialect._returningval( - self.out_parameters["ret_%d" % i] - ) + self.dialect._returningval(self.out_parameters["ret_%d" % i]) for i in range(len(self.out_parameters)) ] return ReturningResultProxy(self, returning_params) @@ -625,8 +634,10 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): result = _result.ResultProxy(self) if self.out_parameters: - if self.compiled_parameters is not None and \ - len(self.compiled_parameters) == 1: + if ( + self.compiled_parameters is not None + and len(self.compiled_parameters) == 1 + ): result.out_parameters = out_parameters = {} for bind, name in self.compiled.bind_names.items(): @@ -634,22 +645,24 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): type = bind.type impl_type = type.dialect_impl(self.dialect) dbapi_type = impl_type.get_dbapi_type( - self.dialect.dbapi) - result_processor = impl_type.\ - result_processor(self.dialect, - dbapi_type) + self.dialect.dbapi + ) + result_processor = impl_type.result_processor( + self.dialect, dbapi_type + ) if result_processor is not None: - out_parameters[name] = \ - result_processor( - self.dialect._paramval( - self.out_parameters[name] - )) + out_parameters[name] = result_processor( + self.dialect._paramval( + self.out_parameters[name] + ) + ) else: out_parameters[name] = self.dialect._paramval( - self.out_parameters[name]) + self.out_parameters[name] + ) else: result.out_parameters = dict( - (k, self._dialect._paramval(v)) + (k, self._dialect._paramval(v)) for k, v in self.out_parameters.items() ) @@ -667,14 +680,11 @@ class ReturningResultProxy(_result.FullyBufferedResultProxy): def _cursor_description(self): returning = self.context.compiled.returning return [ - (getattr(col, 'name', col.anon_label), None) - for col in returning + (getattr(col, "name", col.anon_label), None) for col in returning ] def _buffer_rows(self): - return collections.deque( - [tuple(self._returning_params)] - ) + return collections.deque([tuple(self._returning_params)]) class OracleDialect_cx_oracle(OracleDialect): @@ -696,7 +706,6 @@ class OracleDialect_cx_oracle(OracleDialect): oracle.BINARY_DOUBLE: _OracleBINARY_DOUBLE, sqltypes.Integer: _OracleInteger, oracle.NUMBER: _OracleNUMBER, - sqltypes.Date: _OracleDate, sqltypes.LargeBinary: _OracleBinary, sqltypes.Boolean: oracle._OracleBoolean, @@ -707,7 +716,6 @@ class OracleDialect_cx_oracle(OracleDialect): sqltypes.UnicodeText: _OracleUnicodeTextCLOB, sqltypes.CHAR: _OracleChar, sqltypes.Enum: _OracleEnum, - oracle.LONG: _OracleLong, oracle.RAW: _OracleRaw, sqltypes.Unicode: _OracleUnicodeStringCHAR, @@ -721,13 +729,15 @@ class OracleDialect_cx_oracle(OracleDialect): _cx_oracle_threaded = None - def __init__(self, - auto_convert_lobs=True, - coerce_to_unicode=True, - coerce_to_decimal=True, - arraysize=50, - threaded=None, - **kwargs): + def __init__( + self, + auto_convert_lobs=True, + coerce_to_unicode=True, + coerce_to_decimal=True, + arraysize=50, + threaded=None, + **kwargs + ): OracleDialect.__init__(self, **kwargs) self.arraysize = arraysize @@ -757,15 +767,23 @@ class OracleDialect_cx_oracle(OracleDialect): self.cx_oracle_ver = self._parse_cx_oracle_ver(cx_Oracle.version) if self.cx_oracle_ver < (5, 2) and self.cx_oracle_ver > (0, 0, 0): raise exc.InvalidRequestError( - "cx_Oracle version 5.2 and above are supported") + "cx_Oracle version 5.2 and above are supported" + ) self._has_native_int = hasattr(cx_Oracle, "NATIVE_INT") self._include_setinputsizes = { - cx_Oracle.NCLOB, cx_Oracle.CLOB, cx_Oracle.LOB, - cx_Oracle.NCHAR, cx_Oracle.FIXED_NCHAR, - cx_Oracle.BLOB, cx_Oracle.FIXED_CHAR, cx_Oracle.TIMESTAMP, - _OracleInteger, _OracleBINARY_FLOAT, _OracleBINARY_DOUBLE + cx_Oracle.NCLOB, + cx_Oracle.CLOB, + cx_Oracle.LOB, + cx_Oracle.NCHAR, + cx_Oracle.FIXED_NCHAR, + cx_Oracle.BLOB, + cx_Oracle.FIXED_CHAR, + cx_Oracle.TIMESTAMP, + _OracleInteger, + _OracleBINARY_FLOAT, + _OracleBINARY_DOUBLE, } self._paramval = lambda value: value.getvalue() @@ -786,18 +804,19 @@ class OracleDialect_cx_oracle(OracleDialect): else: self._returningval = self._paramval - self._is_cx_oracle_6 = self.cx_oracle_ver >= (6, ) + self._is_cx_oracle_6 = self.cx_oracle_ver >= (6,) def _pop_deprecated_kwargs(self, kwargs): - auto_setinputsizes = kwargs.pop('auto_setinputsizes', None) - exclude_setinputsizes = kwargs.pop('exclude_setinputsizes', None) + auto_setinputsizes = kwargs.pop("auto_setinputsizes", None) + exclude_setinputsizes = kwargs.pop("exclude_setinputsizes", None) if auto_setinputsizes or exclude_setinputsizes: util.warn_deprecated( "auto_setinputsizes and exclude_setinputsizes are deprecated. " "Modern cx_Oracle only requires that LOB types are part " "of this behavior, and these parameters no longer have any " - "effect.") - allow_twophase = kwargs.pop('allow_twophase', None) + "effect." + ) + allow_twophase = kwargs.pop("allow_twophase", None) if allow_twophase is not None: util.warn.deprecated( "allow_twophase is deprecated. The cx_Oracle dialect no " @@ -805,18 +824,16 @@ class OracleDialect_cx_oracle(OracleDialect): ) def _parse_cx_oracle_ver(self, version): - m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', version) + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version) if m: - return tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + return tuple(int(x) for x in m.group(1, 2, 3) if x is not None) else: return (0, 0, 0) @classmethod def dbapi(cls): import cx_Oracle + return cx_Oracle def initialize(self, connection): @@ -835,15 +852,18 @@ class OracleDialect_cx_oracle(OracleDialect): self._decimal_char = connection.scalar( "select value from nls_session_parameters " - "where parameter = 'NLS_NUMERIC_CHARACTERS'")[0] - if self._decimal_char != '.': + "where parameter = 'NLS_NUMERIC_CHARACTERS'" + )[0] + if self._decimal_char != ".": _detect_decimal = self._detect_decimal _to_decimal = self._to_decimal self._detect_decimal = lambda value: _detect_decimal( - value.replace(self._decimal_char, ".")) + value.replace(self._decimal_char, ".") + ) self._to_decimal = lambda value: _to_decimal( - value.replace(self._decimal_char, ".")) + value.replace(self._decimal_char, ".") + ) def _detect_decimal(self, value): if "." in value: @@ -862,13 +882,16 @@ class OracleDialect_cx_oracle(OracleDialect): dialect = self cx_Oracle = dialect.dbapi - number_handler = _OracleNUMBER(asdecimal=True).\ - _cx_oracle_outputtypehandler(dialect) - float_handler = _OracleNUMBER(asdecimal=False).\ - _cx_oracle_outputtypehandler(dialect) + number_handler = _OracleNUMBER( + asdecimal=True + )._cx_oracle_outputtypehandler(dialect) + float_handler = _OracleNUMBER( + asdecimal=False + )._cx_oracle_outputtypehandler(dialect) - def output_type_handler(cursor, name, default_type, - size, precision, scale): + def output_type_handler( + cursor, name, default_type, size, precision, scale + ): if default_type == cx_Oracle.NUMBER: if not dialect.coerce_to_decimal: return None @@ -879,7 +902,8 @@ class OracleDialect_cx_oracle(OracleDialect): cx_Oracle.STRING, 255, outconverter=dialect._detect_decimal, - arraysize=cursor.arraysize) + arraysize=cursor.arraysize, + ) elif precision and scale > 0: return number_handler( cursor, name, default_type, size, precision, scale @@ -890,43 +914,55 @@ class OracleDialect_cx_oracle(OracleDialect): ) # allow all strings to come back natively as Unicode - elif dialect.coerce_to_unicode and \ - default_type in (cx_Oracle.STRING, cx_Oracle.FIXED_CHAR): + elif dialect.coerce_to_unicode and default_type in ( + cx_Oracle.STRING, + cx_Oracle.FIXED_CHAR, + ): if compat.py2k: outconverter = processors.to_unicode_processor_factory( - dialect.encoding, None) - return cursor.var( - cx_Oracle.STRING, size, cursor.arraysize, - outconverter=outconverter + dialect.encoding, None ) - else: return cursor.var( - util.text_type, size, cursor.arraysize + cx_Oracle.STRING, + size, + cursor.arraysize, + outconverter=outconverter, ) + else: + return cursor.var(util.text_type, size, cursor.arraysize) elif dialect.auto_convert_lobs and default_type in ( - cx_Oracle.CLOB, cx_Oracle.NCLOB + cx_Oracle.CLOB, + cx_Oracle.NCLOB, ): if compat.py2k: outconverter = processors.to_unicode_processor_factory( - dialect.encoding, None) + dialect.encoding, None + ) return cursor.var( - default_type, size, cursor.arraysize, - outconverter=lambda value: outconverter(value.read()) + default_type, + size, + cursor.arraysize, + outconverter=lambda value: outconverter(value.read()), ) else: return cursor.var( - default_type, size, cursor.arraysize, - outconverter=lambda value: value.read() + default_type, + size, + cursor.arraysize, + outconverter=lambda value: value.read(), ) elif dialect.auto_convert_lobs and default_type in ( - cx_Oracle.BLOB, + cx_Oracle.BLOB, ): return cursor.var( - default_type, size, cursor.arraysize, - outconverter=lambda value: value.read() + default_type, + size, + cursor.arraysize, + outconverter=lambda value: value.read(), ) + return output_type_handler def on_connect(self): @@ -941,16 +977,17 @@ class OracleDialect_cx_oracle(OracleDialect): def create_connect_args(self, url): opts = dict(url.query) - for opt in ('use_ansi', 'auto_convert_lobs'): + for opt in ("use_ansi", "auto_convert_lobs"): if opt in opts: util.warn_deprecated( "cx_oracle dialect option %r should only be passed to " - "create_engine directly, not within the URL string" % opt) + "create_engine directly, not within the URL string" % opt + ) util.coerce_kw_type(opts, opt, bool) setattr(self, opt, opts.pop(opt)) database = url.database - service_name = opts.pop('service_name', None) + service_name = opts.pop("service_name", None) if database or service_name: # if we have a database, then we have a remote host port = url.port @@ -962,11 +999,12 @@ class OracleDialect_cx_oracle(OracleDialect): if database and service_name: raise exc.InvalidRequestError( '"service_name" option shouldn\'t ' - 'be used with a "database" part of the url') + 'be used with a "database" part of the url' + ) if database: - makedsn_kwargs = {'sid': database} + makedsn_kwargs = {"sid": database} if service_name: - makedsn_kwargs = {'service_name': service_name} + makedsn_kwargs = {"service_name": service_name} dsn = self.dbapi.makedsn(url.host, port, **makedsn_kwargs) else: @@ -974,11 +1012,11 @@ class OracleDialect_cx_oracle(OracleDialect): dsn = url.host if dsn is not None: - opts['dsn'] = dsn + opts["dsn"] = dsn if url.password is not None: - opts['password'] = url.password + opts["password"] = url.password if url.username is not None: - opts['user'] = url.username + opts["user"] = url.username if self._cx_oracle_threaded is not None: opts.setdefault("threaded", self._cx_oracle_threaded) @@ -995,28 +1033,24 @@ class OracleDialect_cx_oracle(OracleDialect): else: return value - util.coerce_kw_type(opts, 'mode', convert_cx_oracle_constant) - util.coerce_kw_type(opts, 'threaded', bool) - util.coerce_kw_type(opts, 'events', bool) - util.coerce_kw_type(opts, 'purity', convert_cx_oracle_constant) + util.coerce_kw_type(opts, "mode", convert_cx_oracle_constant) + util.coerce_kw_type(opts, "threaded", bool) + util.coerce_kw_type(opts, "events", bool) + util.coerce_kw_type(opts, "purity", convert_cx_oracle_constant) return ([], opts) def _get_server_version_info(self, connection): - return tuple( - int(x) - for x in connection.connection.version.split('.') - ) + return tuple(int(x) for x in connection.connection.version.split(".")) def is_disconnect(self, e, connection, cursor): error, = e.args if isinstance( - e, - (self.dbapi.InterfaceError, self.dbapi.DatabaseError) + e, (self.dbapi.InterfaceError, self.dbapi.DatabaseError) ) and "not connected" in str(e): return True - if hasattr(error, 'code'): + if hasattr(error, "code"): # ORA-00028: your session has been killed # ORA-03114: not connected to ORACLE # ORA-03113: end-of-file on communication channel @@ -1052,22 +1086,25 @@ class OracleDialect_cx_oracle(OracleDialect): def do_prepare_twophase(self, connection, xid): result = connection.connection.prepare() - connection.info['cx_oracle_prepared'] = result + connection.info["cx_oracle_prepared"] = result - def do_rollback_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): self.do_rollback(connection.connection) - def do_commit_twophase(self, connection, xid, is_prepared=True, - recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if not is_prepared: self.do_commit(connection.connection) else: - oci_prepared = connection.info['cx_oracle_prepared'] + oci_prepared = connection.info["cx_oracle_prepared"] if oci_prepared: self.do_commit(connection.connection) def do_recover_twophase(self, connection): - connection.info.pop('cx_oracle_prepared', None) + connection.info.pop("cx_oracle_prepared", None) + dialect = OracleDialect_cx_oracle diff --git a/lib/sqlalchemy/dialects/oracle/zxjdbc.py b/lib/sqlalchemy/dialects/oracle/zxjdbc.py index aa2562573..0a365f8b0 100644 --- a/lib/sqlalchemy/dialects/oracle/zxjdbc.py +++ b/lib/sqlalchemy/dialects/oracle/zxjdbc.py @@ -21,9 +21,11 @@ import re from sqlalchemy import sql, types as sqltypes, util from sqlalchemy.connectors.zxJDBC import ZxJDBCConnector -from sqlalchemy.dialects.oracle.base import (OracleCompiler, - OracleDialect, - OracleExecutionContext) +from sqlalchemy.dialects.oracle.base import ( + OracleCompiler, + OracleDialect, + OracleExecutionContext, +) from sqlalchemy.engine import result as _result from sqlalchemy.sql import expression import collections @@ -32,92 +34,100 @@ SQLException = zxJDBC = None class _ZxJDBCDate(sqltypes.Date): - def result_processor(self, dialect, coltype): def process(value): if value is None: return None else: return value.date() + return process class _ZxJDBCNumeric(sqltypes.Numeric): - def result_processor(self, dialect, coltype): # XXX: does the dialect return Decimal or not??? # if it does (in all cases), we could use a None processor as well as # the to_float generic processor if self.asdecimal: + def process(value): if isinstance(value, decimal.Decimal): return value else: return decimal.Decimal(str(value)) + else: + def process(value): if isinstance(value, decimal.Decimal): return float(value) else: return value + return process class OracleCompiler_zxjdbc(OracleCompiler): - def returning_clause(self, stmt, returning_cols): self.returning_cols = list( - expression._select_iterables(returning_cols)) + expression._select_iterables(returning_cols) + ) # within_columns_clause=False so that labels (foo AS bar) don't render - columns = [self.process(c, within_columns_clause=False) - for c in self.returning_cols] + columns = [ + self.process(c, within_columns_clause=False) + for c in self.returning_cols + ] - if not hasattr(self, 'returning_parameters'): + if not hasattr(self, "returning_parameters"): self.returning_parameters = [] binds = [] for i, col in enumerate(self.returning_cols): - dbtype = col.type.dialect_impl( - self.dialect).get_dbapi_type(self.dialect.dbapi) + dbtype = col.type.dialect_impl(self.dialect).get_dbapi_type( + self.dialect.dbapi + ) self.returning_parameters.append((i + 1, dbtype)) bindparam = sql.bindparam( - "ret_%d" % i, value=ReturningParam(dbtype)) + "ret_%d" % i, value=ReturningParam(dbtype) + ) self.binds[bindparam.key] = bindparam binds.append( - self.bindparam_string(self._truncate_bindparam(bindparam))) + self.bindparam_string(self._truncate_bindparam(bindparam)) + ) - return 'RETURNING ' + ', '.join(columns) + " INTO " + ", ".join(binds) + return "RETURNING " + ", ".join(columns) + " INTO " + ", ".join(binds) class OracleExecutionContext_zxjdbc(OracleExecutionContext): - def pre_exec(self): - if hasattr(self.compiled, 'returning_parameters'): + if hasattr(self.compiled, "returning_parameters"): # prepare a zxJDBC statement so we can grab its underlying # OraclePreparedStatement's getReturnResultSet later self.statement = self.cursor.prepare(self.statement) def get_result_proxy(self): - if hasattr(self.compiled, 'returning_parameters'): + if hasattr(self.compiled, "returning_parameters"): rrs = None try: try: rrs = self.statement.__statement__.getReturnResultSet() next(rrs) except SQLException as sqle: - msg = '%s [SQLCode: %d]' % ( - sqle.getMessage(), sqle.getErrorCode()) + msg = "%s [SQLCode: %d]" % ( + sqle.getMessage(), + sqle.getErrorCode(), + ) if sqle.getSQLState() is not None: - msg += ' [SQLState: %s]' % sqle.getSQLState() + msg += " [SQLState: %s]" % sqle.getSQLState() raise zxJDBC.Error(msg) else: row = tuple( - self.cursor.datahandler.getPyObject( - rrs, index, dbtype) - for index, dbtype in - self.compiled.returning_parameters) + self.cursor.datahandler.getPyObject(rrs, index, dbtype) + for index, dbtype in self.compiled.returning_parameters + ) return ReturningResultProxy(self, row) finally: if rrs is not None: @@ -146,7 +156,7 @@ class ReturningResultProxy(_result.FullyBufferedResultProxy): def _cursor_description(self): ret = [] for c in self.context.compiled.returning_cols: - if hasattr(c, 'name'): + if hasattr(c, "name"): ret.append((c.name, c.type)) else: ret.append((c.anon_label, c.type)) @@ -178,23 +188,24 @@ class ReturningParam(object): def __repr__(self): kls = self.__class__ - return '<%s.%s object at 0x%x type=%s>' % ( - kls.__module__, kls.__name__, id(self), self.type) + return "<%s.%s object at 0x%x type=%s>" % ( + kls.__module__, + kls.__name__, + id(self), + self.type, + ) class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect): - jdbc_db_name = 'oracle' - jdbc_driver_name = 'oracle.jdbc.OracleDriver' + jdbc_db_name = "oracle" + jdbc_driver_name = "oracle.jdbc.OracleDriver" statement_compiler = OracleCompiler_zxjdbc execution_ctx_cls = OracleExecutionContext_zxjdbc colspecs = util.update_copy( OracleDialect.colspecs, - { - sqltypes.Date: _ZxJDBCDate, - sqltypes.Numeric: _ZxJDBCNumeric - } + {sqltypes.Date: _ZxJDBCDate, sqltypes.Numeric: _ZxJDBCNumeric}, ) def __init__(self, *args, **kwargs): @@ -212,24 +223,31 @@ class OracleDialect_zxjdbc(ZxJDBCConnector, OracleDialect): statement.registerReturnParameter(index, object.type) elif dbtype is None: OracleDataHandler.setJDBCObject( - self, statement, index, object) + self, statement, index, object + ) else: OracleDataHandler.setJDBCObject( - self, statement, index, object, dbtype) + self, statement, index, object, dbtype + ) + self.DataHandler = OracleReturningDataHandler def initialize(self, connection): super(OracleDialect_zxjdbc, self).initialize(connection) - self.implicit_returning = \ - connection.connection.driverversion >= '10.2' + self.implicit_returning = connection.connection.driverversion >= "10.2" def _create_jdbc_url(self, url): - return 'jdbc:oracle:thin:@%s:%s:%s' % ( - url.host, url.port or 1521, url.database) + return "jdbc:oracle:thin:@%s:%s:%s" % ( + url.host, + url.port or 1521, + url.database, + ) def _get_server_version_info(self, connection): version = re.search( - r'Release ([\d\.]+)', connection.connection.dbversion).group(1) - return tuple(int(x) for x in version.split('.')) + r"Release ([\d\.]+)", connection.connection.dbversion + ).group(1) + return tuple(int(x) for x in version.split(".")) + dialect = OracleDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index 84f720028..9e65484fa 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -5,33 +5,110 @@ # This module is part of SQLAlchemy and is released under # the MIT License: http://www.opensource.org/licenses/mit-license.php -from . import base, psycopg2, pg8000, pypostgresql, pygresql, \ - zxjdbc, psycopg2cffi # noqa +from . import ( + base, + psycopg2, + pg8000, + pypostgresql, + pygresql, + zxjdbc, + psycopg2cffi, +) # noqa -from .base import \ - INTEGER, BIGINT, SMALLINT, VARCHAR, CHAR, TEXT, NUMERIC, FLOAT, REAL, \ - INET, CIDR, UUID, BIT, MACADDR, MONEY, OID, REGCLASS, DOUBLE_PRECISION, \ - TIMESTAMP, TIME, DATE, BYTEA, BOOLEAN, INTERVAL, ENUM, TSVECTOR, \ - DropEnumType, CreateEnumType +from .base import ( + INTEGER, + BIGINT, + SMALLINT, + VARCHAR, + CHAR, + TEXT, + NUMERIC, + FLOAT, + REAL, + INET, + CIDR, + UUID, + BIT, + MACADDR, + MONEY, + OID, + REGCLASS, + DOUBLE_PRECISION, + TIMESTAMP, + TIME, + DATE, + BYTEA, + BOOLEAN, + INTERVAL, + ENUM, + TSVECTOR, + DropEnumType, + CreateEnumType, +) from .hstore import HSTORE, hstore from .json import JSON, JSONB from .array import array, ARRAY, Any, All from .ext import aggregate_order_by, ExcludeConstraint, array_agg from .dml import insert, Insert -from .ranges import INT4RANGE, INT8RANGE, NUMRANGE, DATERANGE, TSRANGE, \ - TSTZRANGE +from .ranges import ( + INT4RANGE, + INT8RANGE, + NUMRANGE, + DATERANGE, + TSRANGE, + TSTZRANGE, +) base.dialect = dialect = psycopg2.dialect __all__ = ( - 'INTEGER', 'BIGINT', 'SMALLINT', 'VARCHAR', 'CHAR', 'TEXT', 'NUMERIC', - 'FLOAT', 'REAL', 'INET', 'CIDR', 'UUID', 'BIT', 'MACADDR', 'MONEY', 'OID', - 'REGCLASS', 'DOUBLE_PRECISION', 'TIMESTAMP', 'TIME', 'DATE', 'BYTEA', - 'BOOLEAN', 'INTERVAL', 'ARRAY', 'ENUM', 'dialect', 'array', 'HSTORE', - 'hstore', 'INT4RANGE', 'INT8RANGE', 'NUMRANGE', 'DATERANGE', - 'TSRANGE', 'TSTZRANGE', 'JSON', 'JSONB', 'Any', 'All', - 'DropEnumType', 'CreateEnumType', 'ExcludeConstraint', - 'aggregate_order_by', 'array_agg', 'insert', 'Insert' + "INTEGER", + "BIGINT", + "SMALLINT", + "VARCHAR", + "CHAR", + "TEXT", + "NUMERIC", + "FLOAT", + "REAL", + "INET", + "CIDR", + "UUID", + "BIT", + "MACADDR", + "MONEY", + "OID", + "REGCLASS", + "DOUBLE_PRECISION", + "TIMESTAMP", + "TIME", + "DATE", + "BYTEA", + "BOOLEAN", + "INTERVAL", + "ARRAY", + "ENUM", + "dialect", + "array", + "HSTORE", + "hstore", + "INT4RANGE", + "INT8RANGE", + "NUMRANGE", + "DATERANGE", + "TSRANGE", + "TSTZRANGE", + "JSON", + "JSONB", + "Any", + "All", + "DropEnumType", + "CreateEnumType", + "ExcludeConstraint", + "aggregate_order_by", + "array_agg", + "insert", + "Insert", ) diff --git a/lib/sqlalchemy/dialects/postgresql/array.py b/lib/sqlalchemy/dialects/postgresql/array.py index b2674046e..07167f9d0 100644 --- a/lib/sqlalchemy/dialects/postgresql/array.py +++ b/lib/sqlalchemy/dialects/postgresql/array.py @@ -78,7 +78,8 @@ class array(expression.Tuple): :class:`.postgresql.ARRAY` """ - __visit_name__ = 'array' + + __visit_name__ = "array" def __init__(self, clauses, **kw): super(array, self).__init__(*clauses, **kw) @@ -90,18 +91,26 @@ class array(expression.Tuple): # a Slice object from that assert isinstance(obj, int) return expression.BindParameter( - None, obj, _compared_to_operator=operator, + None, + obj, + _compared_to_operator=operator, type_=type_, - _compared_to_type=self.type, unique=True) + _compared_to_type=self.type, + unique=True, + ) else: - return array([ - self._bind_param(operator, o, _assume_scalar=True, type_=type_) - for o in obj]) + return array( + [ + self._bind_param( + operator, o, _assume_scalar=True, type_=type_ + ) + for o in obj + ] + ) def self_group(self, against=None): - if (against in ( - operators.any_op, operators.all_op, operators.getitem)): + if against in (operators.any_op, operators.all_op, operators.getitem): return expression.Grouping(self) else: return self @@ -180,7 +189,8 @@ class ARRAY(sqltypes.ARRAY): elements of the argument array expression. """ return self.operate( - CONTAINED_BY, other, result_type=sqltypes.Boolean) + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) def overlap(self, other): """Boolean expression. Test if array has elements in common with @@ -190,8 +200,9 @@ class ARRAY(sqltypes.ARRAY): comparator_factory = Comparator - def __init__(self, item_type, as_tuple=False, dimensions=None, - zero_indexes=False): + def __init__( + self, item_type, as_tuple=False, dimensions=None, zero_indexes=False + ): """Construct an ARRAY. E.g.:: @@ -228,8 +239,10 @@ class ARRAY(sqltypes.ARRAY): """ if isinstance(item_type, ARRAY): - raise ValueError("Do not nest ARRAY types; ARRAY(basetype) " - "handles multi-dimensional arrays of basetype") + raise ValueError( + "Do not nest ARRAY types; ARRAY(basetype) " + "handles multi-dimensional arrays of basetype" + ) if isinstance(item_type, type): item_type = item_type() self.item_type = item_type @@ -251,11 +264,17 @@ class ARRAY(sqltypes.ARRAY): def _proc_array(self, arr, itemproc, dim, collection): if dim is None: arr = list(arr) - if dim == 1 or dim is None and ( + if ( + dim == 1 + or dim is None + and ( # this has to be (list, tuple), or at least # not hasattr('__iter__'), since Py3K strings # etc. have __iter__ - not arr or not isinstance(arr[0], (list, tuple))): + not arr + or not isinstance(arr[0], (list, tuple)) + ) + ): if itemproc: return collection(itemproc(x) for x in arr) else: @@ -263,30 +282,33 @@ class ARRAY(sqltypes.ARRAY): else: return collection( self._proc_array( - x, itemproc, + x, + itemproc, dim - 1 if dim is not None else None, - collection) + collection, + ) for x in arr ) def bind_processor(self, dialect): - item_proc = self.item_type.dialect_impl(dialect).\ - bind_processor(dialect) + item_proc = self.item_type.dialect_impl(dialect).bind_processor( + dialect + ) def process(value): if value is None: return value else: return self._proc_array( - value, - item_proc, - self.dimensions, - list) + value, item_proc, self.dimensions, list + ) + return process def result_processor(self, dialect, coltype): - item_proc = self.item_type.dialect_impl(dialect).\ - result_processor(dialect, coltype) + item_proc = self.item_type.dialect_impl(dialect).result_processor( + dialect, coltype + ) def process(value): if value is None: @@ -296,8 +318,11 @@ class ARRAY(sqltypes.ARRAY): value, item_proc, self.dimensions, - tuple if self.as_tuple else list) + tuple if self.as_tuple else list, + ) + return process + colspecs[sqltypes.ARRAY] = ARRAY -ischema_names['_array'] = ARRAY +ischema_names["_array"] = ARRAY diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index d68ab8ef5..11833da57 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -930,57 +930,164 @@ try: except ImportError: _python_UUID = None -from sqlalchemy.types import INTEGER, BIGINT, SMALLINT, VARCHAR, \ - CHAR, TEXT, FLOAT, NUMERIC, \ - DATE, BOOLEAN, REAL +from sqlalchemy.types import ( + INTEGER, + BIGINT, + SMALLINT, + VARCHAR, + CHAR, + TEXT, + FLOAT, + NUMERIC, + DATE, + BOOLEAN, + REAL, +) AUTOCOMMIT_REGEXP = re.compile( - r'\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|' - 'IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)', - re.I | re.UNICODE) + r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|" + "IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)", + re.I | re.UNICODE, +) RESERVED_WORDS = set( - ["all", "analyse", "analyze", "and", "any", "array", "as", "asc", - "asymmetric", "both", "case", "cast", "check", "collate", "column", - "constraint", "create", "current_catalog", "current_date", - "current_role", "current_time", "current_timestamp", "current_user", - "default", "deferrable", "desc", "distinct", "do", "else", "end", - "except", "false", "fetch", "for", "foreign", "from", "grant", "group", - "having", "in", "initially", "intersect", "into", "leading", "limit", - "localtime", "localtimestamp", "new", "not", "null", "of", "off", - "offset", "old", "on", "only", "or", "order", "placing", "primary", - "references", "returning", "select", "session_user", "some", "symmetric", - "table", "then", "to", "trailing", "true", "union", "unique", "user", - "using", "variadic", "when", "where", "window", "with", "authorization", - "between", "binary", "cross", "current_schema", "freeze", "full", - "ilike", "inner", "is", "isnull", "join", "left", "like", "natural", - "notnull", "outer", "over", "overlaps", "right", "similar", "verbose" - ]) + [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "leading", + "limit", + "localtime", + "localtimestamp", + "new", + "not", + "null", + "of", + "off", + "offset", + "old", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", + "authorization", + "between", + "binary", + "cross", + "current_schema", + "freeze", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "notnull", + "outer", + "over", + "overlaps", + "right", + "similar", + "verbose", + ] +) _DECIMAL_TYPES = (1231, 1700) _FLOAT_TYPES = (700, 701, 1021, 1022) _INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) + class BYTEA(sqltypes.LargeBinary): - __visit_name__ = 'BYTEA' + __visit_name__ = "BYTEA" class DOUBLE_PRECISION(sqltypes.Float): - __visit_name__ = 'DOUBLE_PRECISION' + __visit_name__ = "DOUBLE_PRECISION" class INET(sqltypes.TypeEngine): __visit_name__ = "INET" + + PGInet = INET class CIDR(sqltypes.TypeEngine): __visit_name__ = "CIDR" + + PGCidr = CIDR class MACADDR(sqltypes.TypeEngine): __visit_name__ = "MACADDR" + + PGMacAddr = MACADDR @@ -991,6 +1098,7 @@ class MONEY(sqltypes.TypeEngine): .. versionadded:: 1.2 """ + __visit_name__ = "MONEY" @@ -1001,6 +1109,7 @@ class OID(sqltypes.TypeEngine): .. versionadded:: 0.9.5 """ + __visit_name__ = "OID" @@ -1011,18 +1120,17 @@ class REGCLASS(sqltypes.TypeEngine): .. versionadded:: 1.2.7 """ + __visit_name__ = "REGCLASS" class TIMESTAMP(sqltypes.TIMESTAMP): - def __init__(self, timezone=False, precision=None): super(TIMESTAMP, self).__init__(timezone=timezone) self.precision = precision class TIME(sqltypes.TIME): - def __init__(self, timezone=False, precision=None): super(TIME, self).__init__(timezone=timezone) self.precision = precision @@ -1036,7 +1144,8 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): It is known to work on psycopg2 and not pg8000 or zxjdbc. """ - __visit_name__ = 'INTERVAL' + + __visit_name__ = "INTERVAL" native = True def __init__(self, precision=None, fields=None): @@ -1065,11 +1174,12 @@ class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): def python_type(self): return dt.timedelta + PGInterval = INTERVAL class BIT(sqltypes.TypeEngine): - __visit_name__ = 'BIT' + __visit_name__ = "BIT" def __init__(self, length=None, varying=False): if not varying: @@ -1080,6 +1190,7 @@ class BIT(sqltypes.TypeEngine): self.length = length self.varying = varying + PGBit = BIT @@ -1095,7 +1206,8 @@ class UUID(sqltypes.TypeEngine): It is known to work on psycopg2 and not pg8000. """ - __visit_name__ = 'UUID' + + __visit_name__ = "UUID" def __init__(self, as_uuid=False): """Construct a UUID type. @@ -1115,24 +1227,29 @@ class UUID(sqltypes.TypeEngine): def bind_processor(self, dialect): if self.as_uuid: + def process(value): if value is not None: value = util.text_type(value) return value + return process else: return None def result_processor(self, dialect, coltype): if self.as_uuid: + def process(value): if value is not None: value = _python_UUID(value) return value + return process else: return None + PGUuid = UUID @@ -1151,7 +1268,8 @@ class TSVECTOR(sqltypes.TypeEngine): :ref:`postgresql_match` """ - __visit_name__ = 'TSVECTOR' + + __visit_name__ = "TSVECTOR" class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): @@ -1273,12 +1391,12 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): """ kw.setdefault("validate_strings", impl.validate_strings) - kw.setdefault('name', impl.name) - kw.setdefault('schema', impl.schema) - kw.setdefault('inherit_schema', impl.inherit_schema) - kw.setdefault('metadata', impl.metadata) - kw.setdefault('_create_events', False) - kw.setdefault('values_callable', impl.values_callable) + kw.setdefault("name", impl.name) + kw.setdefault("schema", impl.schema) + kw.setdefault("inherit_schema", impl.inherit_schema) + kw.setdefault("metadata", impl.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("values_callable", impl.values_callable) return cls(**kw) def create(self, bind=None, checkfirst=True): @@ -1300,9 +1418,9 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or \ - not bind.dialect.has_type( - bind, self.name, schema=self.schema): + if not checkfirst or not bind.dialect.has_type( + bind, self.name, schema=self.schema + ): bind.execute(CreateEnumType(self)) def drop(self, bind=None, checkfirst=True): @@ -1323,8 +1441,9 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not bind.dialect.supports_native_enum: return - if not checkfirst or \ - bind.dialect.has_type(bind, self.name, schema=self.schema): + if not checkfirst or bind.dialect.has_type( + bind, self.name, schema=self.schema + ): bind.execute(DropEnumType(self)) def _check_for_name_in_memos(self, checkfirst, kw): @@ -1338,12 +1457,12 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): """ if not self.create_type: return True - if '_ddl_runner' in kw: - ddl_runner = kw['_ddl_runner'] - if '_pg_enums' in ddl_runner.memo: - pg_enums = ddl_runner.memo['_pg_enums'] + if "_ddl_runner" in kw: + ddl_runner = kw["_ddl_runner"] + if "_pg_enums" in ddl_runner.memo: + pg_enums = ddl_runner.memo["_pg_enums"] else: - pg_enums = ddl_runner.memo['_pg_enums'] = set() + pg_enums = ddl_runner.memo["_pg_enums"] = set() present = (self.schema, self.name) in pg_enums pg_enums.add((self.schema, self.name)) return present @@ -1351,16 +1470,22 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): return False def _on_table_create(self, target, bind, checkfirst=False, **kw): - if checkfirst or ( - not self.metadata and - not kw.get('_is_metadata_operation', False)) and \ - not self._check_for_name_in_memos(checkfirst, kw): + if ( + checkfirst + or ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + ) + and not self._check_for_name_in_memos(checkfirst, kw) + ): self.create(bind=bind, checkfirst=checkfirst) def _on_table_drop(self, target, bind, checkfirst=False, **kw): - if not self.metadata and \ - not kw.get('_is_metadata_operation', False) and \ - not self._check_for_name_in_memos(checkfirst, kw): + if ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + and not self._check_for_name_in_memos(checkfirst, kw) + ): self.drop(bind=bind, checkfirst=checkfirst) def _on_metadata_create(self, target, bind, checkfirst=False, **kw): @@ -1371,49 +1496,46 @@ class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): if not self._check_for_name_in_memos(checkfirst, kw): self.drop(bind=bind, checkfirst=checkfirst) -colspecs = { - sqltypes.Interval: INTERVAL, - sqltypes.Enum: ENUM, -} + +colspecs = {sqltypes.Interval: INTERVAL, sqltypes.Enum: ENUM} ischema_names = { - 'integer': INTEGER, - 'bigint': BIGINT, - 'smallint': SMALLINT, - 'character varying': VARCHAR, - 'character': CHAR, + "integer": INTEGER, + "bigint": BIGINT, + "smallint": SMALLINT, + "character varying": VARCHAR, + "character": CHAR, '"char"': sqltypes.String, - 'name': sqltypes.String, - 'text': TEXT, - 'numeric': NUMERIC, - 'float': FLOAT, - 'real': REAL, - 'inet': INET, - 'cidr': CIDR, - 'uuid': UUID, - 'bit': BIT, - 'bit varying': BIT, - 'macaddr': MACADDR, - 'money': MONEY, - 'oid': OID, - 'regclass': REGCLASS, - 'double precision': DOUBLE_PRECISION, - 'timestamp': TIMESTAMP, - 'timestamp with time zone': TIMESTAMP, - 'timestamp without time zone': TIMESTAMP, - 'time with time zone': TIME, - 'time without time zone': TIME, - 'date': DATE, - 'time': TIME, - 'bytea': BYTEA, - 'boolean': BOOLEAN, - 'interval': INTERVAL, - 'tsvector': TSVECTOR + "name": sqltypes.String, + "text": TEXT, + "numeric": NUMERIC, + "float": FLOAT, + "real": REAL, + "inet": INET, + "cidr": CIDR, + "uuid": UUID, + "bit": BIT, + "bit varying": BIT, + "macaddr": MACADDR, + "money": MONEY, + "oid": OID, + "regclass": REGCLASS, + "double precision": DOUBLE_PRECISION, + "timestamp": TIMESTAMP, + "timestamp with time zone": TIMESTAMP, + "timestamp without time zone": TIMESTAMP, + "time with time zone": TIME, + "time without time zone": TIME, + "date": DATE, + "time": TIME, + "bytea": BYTEA, + "boolean": BOOLEAN, + "interval": INTERVAL, + "tsvector": TSVECTOR, } class PGCompiler(compiler.SQLCompiler): - def visit_array(self, element, **kw): return "ARRAY[%s]" % self.visit_clauselist(element, **kw) @@ -1424,77 +1546,75 @@ class PGCompiler(compiler.SQLCompiler): ) def visit_json_getitem_op_binary(self, binary, operator, **kw): - kw['eager_grouping'] = True - return self._generate_generic_binary( - binary, " -> ", **kw - ) + kw["eager_grouping"] = True + return self._generate_generic_binary(binary, " -> ", **kw) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): - kw['eager_grouping'] = True - return self._generate_generic_binary( - binary, " #> ", **kw - ) + kw["eager_grouping"] = True + return self._generate_generic_binary(binary, " #> ", **kw) def visit_getitem_binary(self, binary, operator, **kw): return "%s[%s]" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_aggregate_order_by(self, element, **kw): return "%s ORDER BY %s" % ( self.process(element.target, **kw), - self.process(element.order_by, **kw) + self.process(element.order_by, **kw), ) def visit_match_op_binary(self, binary, operator, **kw): if "postgresql_regconfig" in binary.modifiers: regconfig = self.render_literal_value( - binary.modifiers['postgresql_regconfig'], - sqltypes.STRINGTYPE) + binary.modifiers["postgresql_regconfig"], sqltypes.STRINGTYPE + ) if regconfig: return "%s @@ to_tsquery(%s, %s)" % ( self.process(binary.left, **kw), regconfig, - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) return "%s @@ to_tsquery(%s)" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw) + self.process(binary.right, **kw), ) def visit_ilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s ILIKE %s' % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + return "%s ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_notilike_op_binary(self, binary, operator, **kw): escape = binary.modifiers.get("escape", None) - return '%s NOT ILIKE %s' % \ - (self.process(binary.left, **kw), - self.process(binary.right, **kw)) \ - + ( - ' ESCAPE ' + - self.render_literal_value(escape, sqltypes.STRINGTYPE) - if escape else '' - ) + return "%s NOT ILIKE %s" % ( + self.process(binary.left, **kw), + self.process(binary.right, **kw), + ) + ( + " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE) + if escape + else "" + ) def visit_empty_set_expr(self, element_types): # cast the empty set to the type we are comparing against. if # we are comparing against the null type, pick an arbitrary # datatype for the empty set - return 'SELECT %s WHERE 1!=1' % ( + return "SELECT %s WHERE 1!=1" % ( ", ".join( - "CAST(NULL AS %s)" % self.dialect.type_compiler.process( - INTEGER() if type_._isnull else type_, - ) for type_ in element_types or [INTEGER()] + "CAST(NULL AS %s)" + % self.dialect.type_compiler.process( + INTEGER() if type_._isnull else type_ + ) + for type_ in element_types or [INTEGER()] ), ) @@ -1502,7 +1622,7 @@ class PGCompiler(compiler.SQLCompiler): value = super(PGCompiler, self).render_literal_value(value, type_) if self.dialect._backslash_escapes: - value = value.replace('\\', '\\\\') + value = value.replace("\\", "\\\\") return value def visit_sequence(self, seq, **kw): @@ -1519,7 +1639,7 @@ class PGCompiler(compiler.SQLCompiler): return text def format_from_hint_text(self, sqltext, table, hint, iscrud): - if hint.upper() != 'ONLY': + if hint.upper() != "ONLY": raise exc.CompileError("Unrecognized hint: %r" % hint) return "ONLY " + sqltext @@ -1528,12 +1648,19 @@ class PGCompiler(compiler.SQLCompiler): if select._distinct is True: return "DISTINCT " elif isinstance(select._distinct, (list, tuple)): - return "DISTINCT ON (" + ', '.join( - [self.process(col, **kw) for col in select._distinct] - ) + ") " + return ( + "DISTINCT ON (" + + ", ".join( + [self.process(col, **kw) for col in select._distinct] + ) + + ") " + ) else: - return "DISTINCT ON (" + \ - self.process(select._distinct, **kw) + ") " + return ( + "DISTINCT ON (" + + self.process(select._distinct, **kw) + + ") " + ) else: return "" @@ -1551,8 +1678,9 @@ class PGCompiler(compiler.SQLCompiler): if select._for_update_arg.of: tables = util.OrderedSet( - c.table if isinstance(c, expression.ColumnClause) - else c for c in select._for_update_arg.of) + c.table if isinstance(c, expression.ColumnClause) else c + for c in select._for_update_arg.of + ) tmp += " OF " + ", ".join( self.process(table, ashint=True, use_schema=False, **kw) for table in tables @@ -1572,7 +1700,7 @@ class PGCompiler(compiler.SQLCompiler): for c in expression._select_iterables(returning_cols) ] - return 'RETURNING ' + ', '.join(columns) + return "RETURNING " + ", ".join(columns) def visit_substring_func(self, func, **kw): s = self.process(func.clauses.clauses[0], **kw) @@ -1586,24 +1714,24 @@ class PGCompiler(compiler.SQLCompiler): def _on_conflict_target(self, clause, **kw): if clause.constraint_target is not None: - target_text = 'ON CONSTRAINT %s' % clause.constraint_target + target_text = "ON CONSTRAINT %s" % clause.constraint_target elif clause.inferred_target_elements is not None: - target_text = '(%s)' % ', '.join( - (self.preparer.quote(c) - if isinstance(c, util.string_types) - else - self.process(c, include_table=False, use_schema=False)) + target_text = "(%s)" % ", ".join( + ( + self.preparer.quote(c) + if isinstance(c, util.string_types) + else self.process(c, include_table=False, use_schema=False) + ) for c in clause.inferred_target_elements ) if clause.inferred_target_whereclause is not None: - target_text += ' WHERE %s' % \ - self.process( - clause.inferred_target_whereclause, - include_table=False, - use_schema=False - ) + target_text += " WHERE %s" % self.process( + clause.inferred_target_whereclause, + include_table=False, + use_schema=False, + ) else: - target_text = '' + target_text = "" return target_text @@ -1627,36 +1755,35 @@ class PGCompiler(compiler.SQLCompiler): set_parameters = dict(clause.update_values_to_set) # create a list of column assignment clauses as tuples - insert_statement = self.stack[-1]['selectable'] + insert_statement = self.stack[-1]["selectable"] cols = insert_statement.table.c for c in cols: col_key = c.key if col_key in set_parameters: value = set_parameters.pop(col_key) if elements._is_literal(value): - value = elements.BindParameter( - None, value, type_=c.type - ) + value = elements.BindParameter(None, value, type_=c.type) else: - if isinstance(value, elements.BindParameter) and \ - value.type._isnull: + if ( + isinstance(value, elements.BindParameter) + and value.type._isnull + ): value = value._clone() value.type = c.type value_text = self.process(value.self_group(), use_schema=False) - key_text = ( - self.preparer.quote(col_key) - ) - action_set_ops.append('%s = %s' % (key_text, value_text)) + key_text = self.preparer.quote(col_key) + action_set_ops.append("%s = %s" % (key_text, value_text)) # check for names that don't match columns if set_parameters: util.warn( "Additional column names not matching " - "any column keys in table '%s': %s" % ( + "any column keys in table '%s': %s" + % ( self.statement.table.name, - (", ".join("'%s'" % c for c in set_parameters)) + (", ".join("'%s'" % c for c in set_parameters)), ) ) for k, v in set_parameters.items(): @@ -1666,42 +1793,37 @@ class PGCompiler(compiler.SQLCompiler): else self.process(k, use_schema=False) ) value_text = self.process( - elements._literal_as_binds(v), - use_schema=False + elements._literal_as_binds(v), use_schema=False ) - action_set_ops.append('%s = %s' % (key_text, value_text)) + action_set_ops.append("%s = %s" % (key_text, value_text)) - action_text = ', '.join(action_set_ops) + action_text = ", ".join(action_set_ops) if clause.update_whereclause is not None: - action_text += ' WHERE %s' % \ - self.process( - clause.update_whereclause, - include_table=True, - use_schema=False - ) + action_text += " WHERE %s" % self.process( + clause.update_whereclause, include_table=True, use_schema=False + ) - return 'ON CONFLICT %s DO UPDATE SET %s' % (target_text, action_text) + return "ON CONFLICT %s DO UPDATE SET %s" % (target_text, action_text) - def update_from_clause(self, update_stmt, - from_table, extra_froms, - from_hints, - **kw): - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + def update_from_clause( + self, update_stmt, from_table, extra_froms, from_hints, **kw + ): + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. USING clause specific to PostgreSQL.""" - return "USING " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in extra_froms) + return "USING " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in extra_froms + ) class PGDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): colspec = self.preparer.format_column(column) @@ -1709,17 +1831,21 @@ class PGDDLCompiler(compiler.DDLCompiler): if isinstance(impl_type, sqltypes.TypeDecorator): impl_type = impl_type.impl - if column.primary_key and \ - column is column.table._autoincrement_column and \ - ( - self.dialect.supports_smallserial or - not isinstance(impl_type, sqltypes.SmallInteger) - ) and ( - column.default is None or - ( - isinstance(column.default, schema.Sequence) and - column.default.optional - )): + if ( + column.primary_key + and column is column.table._autoincrement_column + and ( + self.dialect.supports_smallserial + or not isinstance(impl_type, sqltypes.SmallInteger) + ) + and ( + column.default is None + or ( + isinstance(column.default, schema.Sequence) + and column.default.optional + ) + ) + ): if isinstance(impl_type, sqltypes.BigInteger): colspec += " BIGSERIAL" elif isinstance(impl_type, sqltypes.SmallInteger): @@ -1728,7 +1854,8 @@ class PGDDLCompiler(compiler.DDLCompiler): colspec += " SERIAL" else: colspec += " " + self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ) default = self.get_column_default_string(column) if default is not None: colspec += " DEFAULT " + default @@ -1744,15 +1871,14 @@ class PGDDLCompiler(compiler.DDLCompiler): self.preparer.format_type(type_), ", ".join( self.sql_compiler.process(sql.literal(e), literal_binds=True) - for e in type_.enums) + for e in type_.enums + ), ) def visit_drop_enum_type(self, drop): type_ = drop.element - return "DROP TYPE %s" % ( - self.preparer.format_type(type_) - ) + return "DROP TYPE %s" % (self.preparer.format_type(type_)) def visit_create_index(self, create): preparer = self.preparer @@ -1764,46 +1890,53 @@ class PGDDLCompiler(compiler.DDLCompiler): text += "INDEX " if self.dialect._supports_create_index_concurrently: - concurrently = index.dialect_options['postgresql']['concurrently'] + concurrently = index.dialect_options["postgresql"]["concurrently"] if concurrently: text += "CONCURRENTLY " text += "%s ON %s " % ( - self._prepared_index_name(index, - include_schema=False), - preparer.format_table(index.table) + self._prepared_index_name(index, include_schema=False), + preparer.format_table(index.table), ) - using = index.dialect_options['postgresql']['using'] + using = index.dialect_options["postgresql"]["using"] if using: text += "USING %s " % preparer.quote(using) ops = index.dialect_options["postgresql"]["ops"] - text += "(%s)" \ - % ( - ', '.join([ - self.sql_compiler.process( - expr.self_group() - if not isinstance(expr, expression.ColumnClause) - else expr, - include_table=False, literal_binds=True) + - ( - (' ' + ops[expr.key]) - if hasattr(expr, 'key') - and expr.key in ops else '' - ) - for expr in index.expressions - ]) - ) + text += "(%s)" % ( + ", ".join( + [ + self.sql_compiler.process( + expr.self_group() + if not isinstance(expr, expression.ColumnClause) + else expr, + include_table=False, + literal_binds=True, + ) + + ( + (" " + ops[expr.key]) + if hasattr(expr, "key") and expr.key in ops + else "" + ) + for expr in index.expressions + ] + ) + ) - withclause = index.dialect_options['postgresql']['with'] + withclause = index.dialect_options["postgresql"]["with"] if withclause: - text += " WITH (%s)" % (', '.join( - ['%s = %s' % storage_parameter - for storage_parameter in withclause.items()])) + text += " WITH (%s)" % ( + ", ".join( + [ + "%s = %s" % storage_parameter + for storage_parameter in withclause.items() + ] + ) + ) - tablespace_name = index.dialect_options['postgresql']['tablespace'] + tablespace_name = index.dialect_options["postgresql"]["tablespace"] if tablespace_name: text += " TABLESPACE %s" % preparer.quote(tablespace_name) @@ -1812,8 +1945,8 @@ class PGDDLCompiler(compiler.DDLCompiler): if whereclause is not None: where_compiled = self.sql_compiler.process( - whereclause, include_table=False, - literal_binds=True) + whereclause, include_table=False, literal_binds=True + ) text += " WHERE " + where_compiled return text @@ -1823,7 +1956,7 @@ class PGDDLCompiler(compiler.DDLCompiler): text = "\nDROP INDEX " if self.dialect._supports_drop_index_concurrently: - concurrently = index.dialect_options['postgresql']['concurrently'] + concurrently = index.dialect_options["postgresql"]["concurrently"] if concurrently: text += "CONCURRENTLY " @@ -1833,55 +1966,59 @@ class PGDDLCompiler(compiler.DDLCompiler): def visit_exclude_constraint(self, constraint, **kw): text = "" if constraint.name is not None: - text += "CONSTRAINT %s " % \ - self.preparer.format_constraint(constraint) + text += "CONSTRAINT %s " % self.preparer.format_constraint( + constraint + ) elements = [] for expr, name, op in constraint._render_exprs: - kw['include_table'] = False + kw["include_table"] = False elements.append( "%s WITH %s" % (self.sql_compiler.process(expr, **kw), op) ) - text += "EXCLUDE USING %s (%s)" % (constraint.using, - ', '.join(elements)) + text += "EXCLUDE USING %s (%s)" % ( + constraint.using, + ", ".join(elements), + ) if constraint.where is not None: - text += ' WHERE (%s)' % self.sql_compiler.process( - constraint.where, - literal_binds=True) + text += " WHERE (%s)" % self.sql_compiler.process( + constraint.where, literal_binds=True + ) text += self.define_constraint_deferrability(constraint) return text def post_create_table(self, table): table_opts = [] - pg_opts = table.dialect_options['postgresql'] + pg_opts = table.dialect_options["postgresql"] - inherits = pg_opts.get('inherits') + inherits = pg_opts.get("inherits") if inherits is not None: if not isinstance(inherits, (list, tuple)): - inherits = (inherits, ) + inherits = (inherits,) table_opts.append( - '\n INHERITS ( ' + - ', '.join(self.preparer.quote(name) for name in inherits) + - ' )') + "\n INHERITS ( " + + ", ".join(self.preparer.quote(name) for name in inherits) + + " )" + ) - if pg_opts['partition_by']: - table_opts.append('\n PARTITION BY %s' % pg_opts['partition_by']) + if pg_opts["partition_by"]: + table_opts.append("\n PARTITION BY %s" % pg_opts["partition_by"]) - if pg_opts['with_oids'] is True: - table_opts.append('\n WITH OIDS') - elif pg_opts['with_oids'] is False: - table_opts.append('\n WITHOUT OIDS') + if pg_opts["with_oids"] is True: + table_opts.append("\n WITH OIDS") + elif pg_opts["with_oids"] is False: + table_opts.append("\n WITHOUT OIDS") - if pg_opts['on_commit']: - on_commit_options = pg_opts['on_commit'].replace("_", " ").upper() - table_opts.append('\n ON COMMIT %s' % on_commit_options) + if pg_opts["on_commit"]: + on_commit_options = pg_opts["on_commit"].replace("_", " ").upper() + table_opts.append("\n ON COMMIT %s" % on_commit_options) - if pg_opts['tablespace']: - tablespace_name = pg_opts['tablespace'] + if pg_opts["tablespace"]: + tablespace_name = pg_opts["tablespace"] table_opts.append( - '\n TABLESPACE %s' % self.preparer.quote(tablespace_name) + "\n TABLESPACE %s" % self.preparer.quote(tablespace_name) ) - return ''.join(table_opts) + return "".join(table_opts) class PGTypeCompiler(compiler.GenericTypeCompiler): @@ -1910,7 +2047,7 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): if not type_.precision: return "FLOAT" else: - return "FLOAT(%(precision)s)" % {'precision': type_.precision} + return "FLOAT(%(precision)s)" % {"precision": type_.precision} def visit_DOUBLE_PRECISION(self, type_, **kw): return "DOUBLE PRECISION" @@ -1960,15 +2097,17 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): def visit_TIMESTAMP(self, type_, **kw): return "TIMESTAMP%s %s" % ( "(%d)" % type_.precision - if getattr(type_, 'precision', None) is not None else "", - (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + if getattr(type_, "precision", None) is not None + else "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) def visit_TIME(self, type_, **kw): return "TIME%s %s" % ( "(%d)" % type_.precision - if getattr(type_, 'precision', None) is not None else "", - (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE" + if getattr(type_, "precision", None) is not None + else "", + (type_.timezone and "WITH" or "WITHOUT") + " TIME ZONE", ) def visit_INTERVAL(self, type_, **kw): @@ -2002,13 +2141,16 @@ class PGTypeCompiler(compiler.GenericTypeCompiler): # TODO: pass **kw? inner = self.process(type_.item_type) return re.sub( - r'((?: COLLATE.*)?)$', - (r'%s\1' % ( - "[]" * - (type_.dimensions if type_.dimensions is not None else 1) - )), + r"((?: COLLATE.*)?)$", + ( + r"%s\1" + % ( + "[]" + * (type_.dimensions if type_.dimensions is not None else 1) + ) + ), inner, - count=1 + count=1, ) @@ -2018,8 +2160,9 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): def _unquote_identifier(self, value): if value[0] == self.initial_quote: - value = value[1:-1].\ - replace(self.escape_to_quote, self.escape_quote) + value = value[1:-1].replace( + self.escape_to_quote, self.escape_quote + ) return value def format_type(self, type_, use_schema=True): @@ -2029,22 +2172,25 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): name = self.quote(type_.name) effective_schema = self.schema_for_object(type_) - if not self.omit_schema and use_schema and \ - effective_schema is not None: + if ( + not self.omit_schema + and use_schema + and effective_schema is not None + ): name = self.quote_schema(effective_schema) + "." + name return name class PGInspector(reflection.Inspector): - def __init__(self, conn): reflection.Inspector.__init__(self, conn) def get_table_oid(self, table_name, schema=None): """Return the OID for the given table name.""" - return self.dialect.get_table_oid(self.bind, table_name, schema, - info_cache=self.info_cache) + return self.dialect.get_table_oid( + self.bind, table_name, schema, info_cache=self.info_cache + ) def get_enums(self, schema=None): """Return a list of ENUM objects. @@ -2080,7 +2226,7 @@ class PGInspector(reflection.Inspector): schema = schema or self.default_schema_name return self.dialect._get_foreign_table_names(self.bind, schema) - def get_view_names(self, schema=None, include=('plain', 'materialized')): + def get_view_names(self, schema=None, include=("plain", "materialized")): """Return all view names in `schema`. :param schema: Optional, retrieve names from a non-default schema. @@ -2094,9 +2240,9 @@ class PGInspector(reflection.Inspector): """ - return self.dialect.get_view_names(self.bind, schema, - info_cache=self.info_cache, - include=include) + return self.dialect.get_view_names( + self.bind, schema, info_cache=self.info_cache, include=include + ) class CreateEnumType(schema._CreateDropBase): @@ -2108,25 +2254,27 @@ class DropEnumType(schema._CreateDropBase): class PGExecutionContext(default.DefaultExecutionContext): - def fire_sequence(self, seq, type_): - return self._execute_scalar(( - "select nextval('%s')" % - self.dialect.identifier_preparer.format_sequence(seq)), type_) + return self._execute_scalar( + ( + "select nextval('%s')" + % self.dialect.identifier_preparer.format_sequence(seq) + ), + type_, + ) def get_insert_default(self, column): - if column.primary_key and \ - column is column.table._autoincrement_column: + if column.primary_key and column is column.table._autoincrement_column: if column.server_default and column.server_default.has_argument: # pre-execute passive defaults on primary key columns - return self._execute_scalar("select %s" % - column.server_default.arg, - column.type) + return self._execute_scalar( + "select %s" % column.server_default.arg, column.type + ) - elif (column.default is None or - (column.default.is_sequence and - column.default.optional)): + elif column.default is None or ( + column.default.is_sequence and column.default.optional + ): # execute the sequence associated with a SERIAL primary # key column. for non-primary-key SERIAL, the ID just @@ -2137,23 +2285,25 @@ class PGExecutionContext(default.DefaultExecutionContext): except AttributeError: tab = column.table.name col = column.name - tab = tab[0:29 + max(0, (29 - len(col)))] - col = col[0:29 + max(0, (29 - len(tab)))] + tab = tab[0 : 29 + max(0, (29 - len(col)))] + col = col[0 : 29 + max(0, (29 - len(tab)))] name = "%s_%s_seq" % (tab, col) column._postgresql_seq_name = seq_name = name if column.table is not None: effective_schema = self.connection.schema_for_object( - column.table) + column.table + ) else: effective_schema = None if effective_schema is not None: - exc = "select nextval('\"%s\".\"%s\"')" % \ - (effective_schema, seq_name) + exc = 'select nextval(\'"%s"."%s"\')' % ( + effective_schema, + seq_name, + ) else: - exc = "select nextval('\"%s\"')" % \ - (seq_name, ) + exc = "select nextval('\"%s\"')" % (seq_name,) return self._execute_scalar(exc, column.type) @@ -2164,7 +2314,7 @@ class PGExecutionContext(default.DefaultExecutionContext): class PGDialect(default.DefaultDialect): - name = 'postgresql' + name = "postgresql" supports_alter = True max_identifier_length = 63 supports_sane_rowcount = True @@ -2182,7 +2332,7 @@ class PGDialect(default.DefaultDialect): supports_default_values = True supports_empty_insert = False supports_multivalues_insert = True - default_paramstyle = 'pyformat' + default_paramstyle = "pyformat" ischema_names = ischema_names colspecs = colspecs @@ -2195,32 +2345,43 @@ class PGDialect(default.DefaultDialect): isolation_level = None construct_arguments = [ - (schema.Index, { - "using": False, - "where": None, - "ops": {}, - "concurrently": False, - "with": {}, - "tablespace": None - }), - (schema.Table, { - "ignore_search_path": False, - "tablespace": None, - "partition_by": None, - "with_oids": None, - "on_commit": None, - "inherits": None - }), + ( + schema.Index, + { + "using": False, + "where": None, + "ops": {}, + "concurrently": False, + "with": {}, + "tablespace": None, + }, + ), + ( + schema.Table, + { + "ignore_search_path": False, + "tablespace": None, + "partition_by": None, + "with_oids": None, + "on_commit": None, + "inherits": None, + }, + ), ] - reflection_options = ('postgresql_ignore_search_path', ) + reflection_options = ("postgresql_ignore_search_path",) _backslash_escapes = True _supports_create_index_concurrently = True _supports_drop_index_concurrently = True - def __init__(self, isolation_level=None, json_serializer=None, - json_deserializer=None, **kwargs): + def __init__( + self, + isolation_level=None, + json_serializer=None, + json_deserializer=None, + **kwargs + ): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_deserializer = json_deserializer @@ -2228,8 +2389,10 @@ class PGDialect(default.DefaultDialect): def initialize(self, connection): super(PGDialect, self).initialize(connection) - self.implicit_returning = self.server_version_info > (8, 2) and \ - self.__dict__.get('implicit_returning', True) + self.implicit_returning = self.server_version_info > ( + 8, + 2, + ) and self.__dict__.get("implicit_returning", True) self.supports_native_enum = self.server_version_info >= (8, 3) if not self.supports_native_enum: self.colspecs = self.colspecs.copy() @@ -2241,45 +2404,57 @@ class PGDialect(default.DefaultDialect): # http://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 self.supports_smallserial = self.server_version_info >= (9, 2) - self._backslash_escapes = self.server_version_info < (8, 2) or \ - connection.scalar( - "show standard_conforming_strings" - ) == 'off' + self._backslash_escapes = ( + self.server_version_info < (8, 2) + or connection.scalar("show standard_conforming_strings") == "off" + ) - self._supports_create_index_concurrently = \ + self._supports_create_index_concurrently = ( self.server_version_info >= (8, 2) - self._supports_drop_index_concurrently = \ - self.server_version_info >= (9, 2) + ) + self._supports_drop_index_concurrently = self.server_version_info >= ( + 9, + 2, + ) def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None - _isolation_lookup = set(['SERIALIZABLE', 'READ UNCOMMITTED', - 'READ COMMITTED', 'REPEATABLE READ']) + _isolation_lookup = set( + [ + "SERIALIZABLE", + "READ UNCOMMITTED", + "READ COMMITTED", + "REPEATABLE READ", + ] + ) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") if level not in self._isolation_lookup: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() cursor.execute( "SET SESSION CHARACTERISTICS AS TRANSACTION " - "ISOLATION LEVEL %s" % level) + "ISOLATION LEVEL %s" % level + ) cursor.execute("COMMIT") cursor.close() def get_isolation_level(self, connection): cursor = connection.cursor() - cursor.execute('show transaction isolation level') + cursor.execute("show transaction isolation level") val = cursor.fetchone()[0] cursor.close() return val.upper() @@ -2290,8 +2465,9 @@ class PGDialect(default.DefaultDialect): def do_prepare_twophase(self, connection, xid): connection.execute("PREPARE TRANSACTION '%s'" % xid) - def do_rollback_twophase(self, connection, xid, - is_prepared=True, recover=False): + def do_rollback_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if is_prepared: if recover: # FIXME: ugly hack to get out of transaction @@ -2305,8 +2481,9 @@ class PGDialect(default.DefaultDialect): else: self.do_rollback(connection.connection) - def do_commit_twophase(self, connection, xid, - is_prepared=True, recover=False): + def do_commit_twophase( + self, connection, xid, is_prepared=True, recover=False + ): if is_prepared: if recover: connection.execute("ROLLBACK") @@ -2318,22 +2495,27 @@ class PGDialect(default.DefaultDialect): def do_recover_twophase(self, connection): resultset = connection.execute( - sql.text("SELECT gid FROM pg_prepared_xacts")) + sql.text("SELECT gid FROM pg_prepared_xacts") + ) return [row[0] for row in resultset] def _get_default_schema_name(self, connection): return connection.scalar("select current_schema()") def has_schema(self, connection, schema): - query = ("select nspname from pg_namespace " - "where lower(nspname)=:schema") + query = ( + "select nspname from pg_namespace " "where lower(nspname)=:schema" + ) cursor = connection.execute( sql.text( query, bindparams=[ sql.bindparam( - 'schema', util.text_type(schema.lower()), - type_=sqltypes.Unicode)] + "schema", + util.text_type(schema.lower()), + type_=sqltypes.Unicode, + ) + ], ) ) @@ -2349,8 +2531,12 @@ class PGDialect(default.DefaultDialect): "pg_catalog.pg_table_is_visible(c.oid) " "and relname=:name", bindparams=[ - sql.bindparam('name', util.text_type(table_name), - type_=sqltypes.Unicode)] + sql.bindparam( + "name", + util.text_type(table_name), + type_=sqltypes.Unicode, + ) + ], ) ) else: @@ -2360,12 +2546,17 @@ class PGDialect(default.DefaultDialect): "n.oid=c.relnamespace where n.nspname=:schema and " "relname=:name", bindparams=[ - sql.bindparam('name', - util.text_type(table_name), - type_=sqltypes.Unicode), - sql.bindparam('schema', - util.text_type(schema), - type_=sqltypes.Unicode)] + sql.bindparam( + "name", + util.text_type(table_name), + type_=sqltypes.Unicode, + ), + sql.bindparam( + "schema", + util.text_type(schema), + type_=sqltypes.Unicode, + ), + ], ) ) return bool(cursor.first()) @@ -2379,9 +2570,12 @@ class PGDialect(default.DefaultDialect): "n.nspname=current_schema() " "and relname=:name", bindparams=[ - sql.bindparam('name', util.text_type(sequence_name), - type_=sqltypes.Unicode) - ] + sql.bindparam( + "name", + util.text_type(sequence_name), + type_=sqltypes.Unicode, + ) + ], ) ) else: @@ -2391,12 +2585,17 @@ class PGDialect(default.DefaultDialect): "n.oid=c.relnamespace where relkind='S' and " "n.nspname=:schema and relname=:name", bindparams=[ - sql.bindparam('name', util.text_type(sequence_name), - type_=sqltypes.Unicode), - sql.bindparam('schema', - util.text_type(schema), - type_=sqltypes.Unicode) - ] + sql.bindparam( + "name", + util.text_type(sequence_name), + type_=sqltypes.Unicode, + ), + sql.bindparam( + "schema", + util.text_type(schema), + type_=sqltypes.Unicode, + ), + ], ) ) @@ -2423,13 +2622,15 @@ class PGDialect(default.DefaultDialect): """ query = sql.text(query) query = query.bindparams( - sql.bindparam('typname', - util.text_type(type_name), type_=sqltypes.Unicode), + sql.bindparam( + "typname", util.text_type(type_name), type_=sqltypes.Unicode + ) ) if schema is not None: query = query.bindparams( - sql.bindparam('nspname', - util.text_type(schema), type_=sqltypes.Unicode), + sql.bindparam( + "nspname", util.text_type(schema), type_=sqltypes.Unicode + ) ) cursor = connection.execute(query) return bool(cursor.scalar()) @@ -2437,12 +2638,14 @@ class PGDialect(default.DefaultDialect): def _get_server_version_info(self, connection): v = connection.execute("select version()").scalar() m = re.match( - r'.*(?:PostgreSQL|EnterpriseDB) ' - r'(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?', - v) + r".*(?:PostgreSQL|EnterpriseDB) " + r"(\d+)\.?(\d+)?(?:\.(\d+))?(?:\.\d+)?(?:devel|beta)?", + v, + ) if not m: raise AssertionError( - "Could not determine version from string '%s'" % v) + "Could not determine version from string '%s'" % v + ) return tuple([int(x) for x in m.group(1, 2, 3) if x is not None]) @reflection.cache @@ -2459,14 +2662,17 @@ class PGDialect(default.DefaultDialect): schema_where_clause = "n.nspname = :schema" else: schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - query = """ + query = ( + """ SELECT c.oid FROM pg_catalog.pg_class c LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE (%s) AND c.relname = :table_name AND c.relkind in ('r', 'v', 'm', 'f', 'p') - """ % schema_where_clause + """ + % schema_where_clause + ) # Since we're binding to unicode, table_name and schema_name must be # unicode. table_name = util.text_type(table_name) @@ -2475,7 +2681,7 @@ class PGDialect(default.DefaultDialect): s = sql.text(query).bindparams(table_name=sqltypes.Unicode) s = s.columns(oid=sqltypes.Integer) if schema: - s = s.bindparams(sql.bindparam('schema', type_=sqltypes.Unicode)) + s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode)) c = connection.execute(s, table_name=table_name, schema=schema) table_oid = c.scalar() if table_oid is None: @@ -2485,75 +2691,88 @@ class PGDialect(default.DefaultDialect): @reflection.cache def get_schema_names(self, connection, **kw): result = connection.execute( - sql.text("SELECT nspname FROM pg_namespace " - "WHERE nspname NOT LIKE 'pg_%' " - "ORDER BY nspname" - ).columns(nspname=sqltypes.Unicode)) + sql.text( + "SELECT nspname FROM pg_namespace " + "WHERE nspname NOT LIKE 'pg_%' " + "ORDER BY nspname" + ).columns(nspname=sqltypes.Unicode) + ) return [name for name, in result] @reflection.cache def get_table_names(self, connection, schema=None, **kw): result = connection.execute( - sql.text("SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" - ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name) + sql.text( + "SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" + ).columns(relname=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name, + ) return [name for name, in result] @reflection.cache def _get_foreign_table_names(self, connection, schema=None, **kw): result = connection.execute( - sql.text("SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind = 'f'" - ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name) + sql.text( + "SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind = 'f'" + ).columns(relname=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name, + ) return [name for name, in result] @reflection.cache def get_view_names( - self, connection, schema=None, - include=('plain', 'materialized'), **kw): + self, connection, schema=None, include=("plain", "materialized"), **kw + ): - include_kind = {'plain': 'v', 'materialized': 'm'} + include_kind = {"plain": "v", "materialized": "m"} try: kinds = [include_kind[i] for i in util.to_list(include)] except KeyError: raise ValueError( "include %r unknown, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" % (include,)) + "one or both of 'plain' and 'materialized'" % (include,) + ) if not kinds: raise ValueError( "empty include, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'") + "one or both of 'plain' and 'materialized'" + ) result = connection.execute( - sql.text("SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind IN (%s)" % - (", ".join("'%s'" % elem for elem in kinds)) - ).columns(relname=sqltypes.Unicode), - schema=schema if schema is not None else self.default_schema_name) + sql.text( + "SELECT c.relname FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relkind IN (%s)" + % (", ".join("'%s'" % elem for elem in kinds)) + ).columns(relname=sqltypes.Unicode), + schema=schema if schema is not None else self.default_schema_name, + ) return [name for name, in result] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): view_def = connection.scalar( - sql.text("SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relname = :view_name " - "AND c.relkind IN ('v', 'm')" - ).columns(view_def=sqltypes.Unicode), + sql.text( + "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " + "JOIN pg_namespace n ON n.oid = c.relnamespace " + "WHERE n.nspname = :schema AND c.relname = :view_name " + "AND c.relkind IN ('v', 'm')" + ).columns(view_def=sqltypes.Unicode), schema=schema if schema is not None else self.default_schema_name, - view_name=view_name) + view_name=view_name, + ) return view_def @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) SQL_COLS = """ SELECT a.attname, pg_catalog.format_type(a.atttypid, a.atttypmod), @@ -2571,13 +2790,11 @@ class PGDialect(default.DefaultDialect): AND a.attnum > 0 AND NOT a.attisdropped ORDER BY a.attnum """ - s = sql.text(SQL_COLS, - bindparams=[ - sql.bindparam('table_oid', type_=sqltypes.Integer)], - typemap={ - 'attname': sqltypes.Unicode, - 'default': sqltypes.Unicode} - ) + s = sql.text( + SQL_COLS, + bindparams=[sql.bindparam("table_oid", type_=sqltypes.Integer)], + typemap={"attname": sqltypes.Unicode, "default": sqltypes.Unicode}, + ) c = connection.execute(s, table_oid=table_oid) rows = c.fetchall() @@ -2588,34 +2805,58 @@ class PGDialect(default.DefaultDialect): # dictionary with (name, ) if default search path or (schema, name) # as keys enums = dict( - ((rec['name'], ), rec) - if rec['visible'] else ((rec['schema'], rec['name']), rec) - for rec in self._load_enums(connection, schema='*') + ((rec["name"],), rec) + if rec["visible"] + else ((rec["schema"], rec["name"]), rec) + for rec in self._load_enums(connection, schema="*") ) # format columns columns = [] - for name, format_type, default_, notnull, attnum, table_oid, \ - comment in rows: + for ( + name, + format_type, + default_, + notnull, + attnum, + table_oid, + comment, + ) in rows: column_info = self._get_column_info( - name, format_type, default_, notnull, domains, enums, - schema, comment) + name, + format_type, + default_, + notnull, + domains, + enums, + schema, + comment, + ) columns.append(column_info) return columns - def _get_column_info(self, name, format_type, default, - notnull, domains, enums, schema, comment): + def _get_column_info( + self, + name, + format_type, + default, + notnull, + domains, + enums, + schema, + comment, + ): def _handle_array_type(attype): return ( # strip '[]' from integer[], etc. - re.sub(r'\[\]$', '', attype), - attype.endswith('[]'), + re.sub(r"\[\]$", "", attype), + attype.endswith("[]"), ) # strip (*) from character varying(5), timestamp(5) # with time zone, geometry(POLYGON), etc. - attype = re.sub(r'\(.*\)', '', format_type) + attype = re.sub(r"\(.*\)", "", format_type) # strip '[]' from integer[], etc. and check if an array attype, is_array = _handle_array_type(attype) @@ -2625,50 +2866,52 @@ class PGDialect(default.DefaultDialect): nullable = not notnull - charlen = re.search(r'\(([\d,]+)\)', format_type) + charlen = re.search(r"\(([\d,]+)\)", format_type) if charlen: charlen = charlen.group(1) - args = re.search(r'\((.*)\)', format_type) + args = re.search(r"\((.*)\)", format_type) if args and args.group(1): - args = tuple(re.split(r'\s*,\s*', args.group(1))) + args = tuple(re.split(r"\s*,\s*", args.group(1))) else: args = () kwargs = {} - if attype == 'numeric': + if attype == "numeric": if charlen: - prec, scale = charlen.split(',') + prec, scale = charlen.split(",") args = (int(prec), int(scale)) else: args = () - elif attype == 'double precision': - args = (53, ) - elif attype == 'integer': + elif attype == "double precision": + args = (53,) + elif attype == "integer": args = () - elif attype in ('timestamp with time zone', - 'time with time zone'): - kwargs['timezone'] = True + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True if charlen: - kwargs['precision'] = int(charlen) + kwargs["precision"] = int(charlen) args = () - elif attype in ('timestamp without time zone', - 'time without time zone', 'time'): - kwargs['timezone'] = False + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False if charlen: - kwargs['precision'] = int(charlen) + kwargs["precision"] = int(charlen) args = () - elif attype == 'bit varying': - kwargs['varying'] = True + elif attype == "bit varying": + kwargs["varying"] = True if charlen: args = (int(charlen),) else: args = () - elif attype.startswith('interval'): - field_match = re.match(r'interval (.+)', attype, re.I) + elif attype.startswith("interval"): + field_match = re.match(r"interval (.+)", attype, re.I) if charlen: - kwargs['precision'] = int(charlen) + kwargs["precision"] = int(charlen) if field_match: - kwargs['fields'] = field_match.group(1) + kwargs["fields"] = field_match.group(1) attype = "interval" args = () elif charlen: @@ -2682,23 +2925,23 @@ class PGDialect(default.DefaultDialect): elif enum_or_domain_key in enums: enum = enums[enum_or_domain_key] coltype = ENUM - kwargs['name'] = enum['name'] - if not enum['visible']: - kwargs['schema'] = enum['schema'] - args = tuple(enum['labels']) + kwargs["name"] = enum["name"] + if not enum["visible"]: + kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) break elif enum_or_domain_key in domains: domain = domains[enum_or_domain_key] - attype = domain['attype'] + attype = domain["attype"] attype, is_array = _handle_array_type(attype) # strip quotes from case sensitive enum or domain names enum_or_domain_key = tuple(util.quoted_token_parser(attype)) # A table can't override whether the domain is nullable. - nullable = domain['nullable'] - if domain['default'] and not default: + nullable = domain["nullable"] + if domain["default"] and not default: # It can, however, override the default # value, but can't set it to null. - default = domain['default'] + default = domain["default"] continue else: coltype = None @@ -2707,10 +2950,11 @@ class PGDialect(default.DefaultDialect): if coltype: coltype = coltype(*args, **kwargs) if is_array: - coltype = self.ischema_names['_array'](coltype) + coltype = self.ischema_names["_array"](coltype) else: - util.warn("Did not recognize type '%s' of column '%s'" % - (attype, name)) + util.warn( + "Did not recognize type '%s' of column '%s'" % (attype, name) + ) coltype = sqltypes.NULLTYPE # adjust the default value autoincrement = False @@ -2721,23 +2965,33 @@ class PGDialect(default.DefaultDialect): autoincrement = True # the default is related to a Sequence sch = schema - if '.' not in match.group(2) and sch is not None: + if "." not in match.group(2) and sch is not None: # unconditionally quote the schema name. this could # later be enhanced to obey quoting rules / # "quote schema" - default = match.group(1) + \ - ('"%s"' % sch) + '.' + \ - match.group(2) + match.group(3) + default = ( + match.group(1) + + ('"%s"' % sch) + + "." + + match.group(2) + + match.group(3) + ) - column_info = dict(name=name, type=coltype, nullable=nullable, - default=default, autoincrement=autoincrement, - comment=comment) + column_info = dict( + name=name, + type=coltype, + nullable=nullable, + default=default, + autoincrement=autoincrement, + comment=comment, + ) return column_info @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) if self.server_version_info < (8, 4): PK_SQL = """ @@ -2750,7 +3004,9 @@ class PGDialect(default.DefaultDialect): WHERE t.oid = :table_oid and ix.indisprimary = 't' ORDER BY a.attnum - """ % self._pg_index_any("a.attnum", "ix.indkey") + """ % self._pg_index_any( + "a.attnum", "ix.indkey" + ) else: # unnest() and generate_subscripts() both introduced in @@ -2766,7 +3022,7 @@ class PGDialect(default.DefaultDialect): WHERE a.attrelid = :table_oid ORDER BY k.ord """ - t = sql.text(PK_SQL, typemap={'attname': sqltypes.Unicode}) + t = sql.text(PK_SQL, typemap={"attname": sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) cols = [r[0] for r in c.fetchall()] @@ -2776,18 +3032,25 @@ class PGDialect(default.DefaultDialect): WHERE r.conrelid = :table_oid AND r.contype = 'p' ORDER BY 1 """ - t = sql.text(PK_CONS_SQL, typemap={'conname': sqltypes.Unicode}) + t = sql.text(PK_CONS_SQL, typemap={"conname": sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) name = c.scalar() - return {'constrained_columns': cols, 'name': name} + return {"constrained_columns": cols, "name": name} @reflection.cache - def get_foreign_keys(self, connection, table_name, schema=None, - postgresql_ignore_search_path=False, **kw): + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + postgresql_ignore_search_path=False, + **kw + ): preparer = self.identifier_preparer - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) FK_SQL = """ SELECT r.conname, @@ -2805,34 +3068,35 @@ class PGDialect(default.DefaultDialect): """ # http://www.postgresql.org/docs/9.0/static/sql-createtable.html FK_REGEX = re.compile( - r'FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)' - r'[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?' - r'[\s]?(ON UPDATE ' - r'(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' - r'[\s]?(ON DELETE ' - r'(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?' - r'[\s]?(DEFERRABLE|NOT DEFERRABLE)?' - r'[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?' + r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)" + r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" + r"[\s]?(ON UPDATE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(ON DELETE " + r"(CASCADE|RESTRICT|NO ACTION|SET NULL|SET DEFAULT)+)?" + r"[\s]?(DEFERRABLE|NOT DEFERRABLE)?" + r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" ) - t = sql.text(FK_SQL, typemap={ - 'conname': sqltypes.Unicode, - 'condef': sqltypes.Unicode}) + t = sql.text( + FK_SQL, + typemap={"conname": sqltypes.Unicode, "condef": sqltypes.Unicode}, + ) c = connection.execute(t, table=table_oid) fkeys = [] for conname, condef, conschema in c.fetchall(): m = re.search(FK_REGEX, condef).groups() - constrained_columns, referred_schema, \ - referred_table, referred_columns, \ - _, match, _, onupdate, _, ondelete, \ - deferrable, _, initially = m + constrained_columns, referred_schema, referred_table, referred_columns, _, match, _, onupdate, _, ondelete, deferrable, _, initially = ( + m + ) if deferrable is not None: - deferrable = True if deferrable == 'DEFERRABLE' else False - constrained_columns = [preparer._unquote_identifier(x) - for x in re.split( - r'\s*,\s*', constrained_columns)] + deferrable = True if deferrable == "DEFERRABLE" else False + constrained_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s*", constrained_columns) + ] if postgresql_ignore_search_path: # when ignoring search path, we use the actual schema @@ -2845,30 +3109,30 @@ class PGDialect(default.DefaultDialect): # referred_schema is the schema that we regexp'ed from # pg_get_constraintdef(). If the schema is in the search # path, pg_get_constraintdef() will give us None. - referred_schema = \ - preparer._unquote_identifier(referred_schema) + referred_schema = preparer._unquote_identifier(referred_schema) elif schema is not None and schema == conschema: # If the actual schema matches the schema of the table # we're reflecting, then we will use that. referred_schema = schema referred_table = preparer._unquote_identifier(referred_table) - referred_columns = [preparer._unquote_identifier(x) - for x in - re.split(r'\s*,\s', referred_columns)] + referred_columns = [ + preparer._unquote_identifier(x) + for x in re.split(r"\s*,\s", referred_columns) + ] fkey_d = { - 'name': conname, - 'constrained_columns': constrained_columns, - 'referred_schema': referred_schema, - 'referred_table': referred_table, - 'referred_columns': referred_columns, - 'options': { - 'onupdate': onupdate, - 'ondelete': ondelete, - 'deferrable': deferrable, - 'initially': initially, - 'match': match - } + "name": conname, + "constrained_columns": constrained_columns, + "referred_schema": referred_schema, + "referred_table": referred_table, + "referred_columns": referred_columns, + "options": { + "onupdate": onupdate, + "ondelete": ondelete, + "deferrable": deferrable, + "initially": initially, + "match": match, + }, } fkeys.append(fkey_d) return fkeys @@ -2882,16 +3146,16 @@ class PGDialect(default.DefaultDialect): # for now. # regards, tom lane" return "(%s)" % " OR ".join( - "%s[%d] = %s" % (compare_to, ind, col) - for ind in range(0, 10) + "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10) ) else: return "%s = ANY(%s)" % (col, compare_to) @reflection.cache def get_indexes(self, connection, table_name, schema, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) # cast indkey as varchar since it's an int2vector, # returned as a list by some drivers such as pypostgresql @@ -2925,9 +3189,10 @@ class PGDialect(default.DefaultDialect): # cast does not work in PG 8.2.4, does work in 8.3.0. # nothing in PG changelogs regarding this. "::varchar" if self.server_version_info >= (8, 3) else "", - "i.reloptions" if self.server_version_info >= (8, 2) + "i.reloptions" + if self.server_version_info >= (8, 2) else "NULL", - self._pg_index_any("a.attnum", "ix.indkey") + self._pg_index_any("a.attnum", "ix.indkey"), ) else: IDX_SQL = """ @@ -2960,76 +3225,93 @@ class PGDialect(default.DefaultDialect): i.relname """ - t = sql.text(IDX_SQL, typemap={ - 'relname': sqltypes.Unicode, - 'attname': sqltypes.Unicode}) + t = sql.text( + IDX_SQL, + typemap={"relname": sqltypes.Unicode, "attname": sqltypes.Unicode}, + ) c = connection.execute(t, table_oid=table_oid) indexes = defaultdict(lambda: defaultdict(dict)) sv_idx_name = None for row in c.fetchall(): - (idx_name, unique, expr, prd, col, - col_num, conrelid, idx_key, options, amname) = row + ( + idx_name, + unique, + expr, + prd, + col, + col_num, + conrelid, + idx_key, + options, + amname, + ) = row if expr: if idx_name != sv_idx_name: util.warn( "Skipped unsupported reflection of " - "expression-based index %s" - % idx_name) + "expression-based index %s" % idx_name + ) sv_idx_name = idx_name continue if prd and not idx_name == sv_idx_name: util.warn( "Predicate of partial index %s ignored during reflection" - % idx_name) + % idx_name + ) sv_idx_name = idx_name has_idx = idx_name in indexes index = indexes[idx_name] if col is not None: - index['cols'][col_num] = col + index["cols"][col_num] = col if not has_idx: - index['key'] = [int(k.strip()) for k in idx_key.split()] - index['unique'] = unique + index["key"] = [int(k.strip()) for k in idx_key.split()] + index["unique"] = unique if conrelid is not None: - index['duplicates_constraint'] = idx_name + index["duplicates_constraint"] = idx_name if options: - index['options'] = dict( - [option.split("=") for option in options]) + index["options"] = dict( + [option.split("=") for option in options] + ) # it *might* be nice to include that this is 'btree' in the # reflection info. But we don't want an Index object # to have a ``postgresql_using`` in it that is just the # default, so for the moment leaving this out. - if amname and amname != 'btree': - index['amname'] = amname + if amname and amname != "btree": + index["amname"] = amname result = [] for name, idx in indexes.items(): entry = { - 'name': name, - 'unique': idx['unique'], - 'column_names': [idx['cols'][i] for i in idx['key']] + "name": name, + "unique": idx["unique"], + "column_names": [idx["cols"][i] for i in idx["key"]], } - if 'duplicates_constraint' in idx: - entry['duplicates_constraint'] = idx['duplicates_constraint'] - if 'options' in idx: - entry.setdefault( - 'dialect_options', {})["postgresql_with"] = idx['options'] - if 'amname' in idx: - entry.setdefault( - 'dialect_options', {})["postgresql_using"] = idx['amname'] + if "duplicates_constraint" in idx: + entry["duplicates_constraint"] = idx["duplicates_constraint"] + if "options" in idx: + entry.setdefault("dialect_options", {})[ + "postgresql_with" + ] = idx["options"] + if "amname" in idx: + entry.setdefault("dialect_options", {})[ + "postgresql_using" + ] = idx["amname"] result.append(entry) return result @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) UNIQUE_SQL = """ SELECT @@ -3047,7 +3329,7 @@ class PGDialect(default.DefaultDialect): cons.contype = 'u' """ - t = sql.text(UNIQUE_SQL, typemap={'col_name': sqltypes.Unicode}) + t = sql.text(UNIQUE_SQL, typemap={"col_name": sqltypes.Unicode}) c = connection.execute(t, table_oid=table_oid) uniques = defaultdict(lambda: defaultdict(dict)) @@ -3057,15 +3339,15 @@ class PGDialect(default.DefaultDialect): uc["cols"][row.col_num] = row.col_name return [ - {'name': name, - 'column_names': [uc["cols"][i] for i in uc["key"]]} + {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]} for name, uc in uniques.items() ] @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) COMMENT_SQL = """ SELECT @@ -3081,10 +3363,10 @@ class PGDialect(default.DefaultDialect): return {"text": c.scalar()} @reflection.cache - def get_check_constraints( - self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid(connection, table_name, schema, - info_cache=kw.get('info_cache')) + def get_check_constraints(self, connection, table_name, schema=None, **kw): + table_oid = self.get_table_oid( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) CHECK_SQL = """ SELECT @@ -3100,10 +3382,8 @@ class PGDialect(default.DefaultDialect): c = connection.execute(sql.text(CHECK_SQL), table_oid=table_oid) return [ - {'name': name, - 'sqltext': src[1:-1]} - for name, src in c.fetchall() - ] + {"name": name, "sqltext": src[1:-1]} for name, src in c.fetchall() + ] def _load_enums(self, connection, schema=None): schema = schema or self.default_schema_name @@ -3124,17 +3404,18 @@ class PGDialect(default.DefaultDialect): WHERE t.typtype = 'e' """ - if schema != '*': + if schema != "*": SQL_ENUMS += "AND n.nspname = :schema " # e.oid gives us label order within an enum SQL_ENUMS += 'ORDER BY "schema", "name", e.oid' - s = sql.text(SQL_ENUMS, typemap={ - 'attname': sqltypes.Unicode, - 'label': sqltypes.Unicode}) + s = sql.text( + SQL_ENUMS, + typemap={"attname": sqltypes.Unicode, "label": sqltypes.Unicode}, + ) - if schema != '*': + if schema != "*": s = s.bindparams(schema=schema) c = connection.execute(s) @@ -3142,15 +3423,15 @@ class PGDialect(default.DefaultDialect): enums = [] enum_by_name = {} for enum in c.fetchall(): - key = (enum['schema'], enum['name']) + key = (enum["schema"], enum["name"]) if key in enum_by_name: - enum_by_name[key]['labels'].append(enum['label']) + enum_by_name[key]["labels"].append(enum["label"]) else: enum_by_name[key] = enum_rec = { - 'name': enum['name'], - 'schema': enum['schema'], - 'visible': enum['visible'], - 'labels': [enum['label']], + "name": enum["name"], + "schema": enum["schema"], + "visible": enum["visible"], + "labels": [enum["label"]], } enums.append(enum_rec) return enums @@ -3169,26 +3450,26 @@ class PGDialect(default.DefaultDialect): WHERE t.typtype = 'd' """ - s = sql.text(SQL_DOMAINS, typemap={'attname': sqltypes.Unicode}) + s = sql.text(SQL_DOMAINS, typemap={"attname": sqltypes.Unicode}) c = connection.execute(s) domains = {} for domain in c.fetchall(): # strip (30) from character varying(30) - attype = re.search(r'([^\(]+)', domain['attype']).group(1) + attype = re.search(r"([^\(]+)", domain["attype"]).group(1) # 'visible' just means whether or not the domain is in a # schema that's on the search path -- or not overridden by # a schema with higher precedence. If it's not visible, # it will be prefixed with the schema-name when it's used. - if domain['visible']: - key = (domain['name'], ) + if domain["visible"]: + key = (domain["name"],) else: - key = (domain['schema'], domain['name']) + key = (domain["schema"], domain["name"]) domains[key] = { - 'attype': attype, - 'nullable': domain['nullable'], - 'default': domain['default'] + "attype": attype, + "nullable": domain["nullable"], + "default": domain["default"], } return domains diff --git a/lib/sqlalchemy/dialects/postgresql/dml.py b/lib/sqlalchemy/dialects/postgresql/dml.py index 555a9006c..825f13238 100644 --- a/lib/sqlalchemy/dialects/postgresql/dml.py +++ b/lib/sqlalchemy/dialects/postgresql/dml.py @@ -14,7 +14,7 @@ from ...sql.base import _generative from ... import util from . import ext -__all__ = ('Insert', 'insert') +__all__ = ("Insert", "insert") class Insert(StandardInsert): @@ -40,13 +40,17 @@ class Insert(StandardInsert): to use :attr:`.Insert.excluded` """ - return alias(self.table, name='excluded').columns + return alias(self.table, name="excluded").columns @_generative def on_conflict_do_update( - self, - constraint=None, index_elements=None, - index_where=None, set_=None, where=None): + self, + constraint=None, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): """ Specifies a DO UPDATE SET action for ON CONFLICT clause. @@ -96,13 +100,14 @@ class Insert(StandardInsert): """ self._post_values_clause = OnConflictDoUpdate( - constraint, index_elements, index_where, set_, where) + constraint, index_elements, index_where, set_, where + ) return self @_generative def on_conflict_do_nothing( - self, - constraint=None, index_elements=None, index_where=None): + self, constraint=None, index_elements=None, index_where=None + ): """ Specifies a DO NOTHING action for ON CONFLICT clause. @@ -130,30 +135,29 @@ class Insert(StandardInsert): """ self._post_values_clause = OnConflictDoNothing( - constraint, index_elements, index_where) + constraint, index_elements, index_where + ) return self -insert = public_factory(Insert, '.dialects.postgresql.insert') + +insert = public_factory(Insert, ".dialects.postgresql.insert") class OnConflictClause(ClauseElement): - def __init__( - self, - constraint=None, - index_elements=None, - index_where=None): + def __init__(self, constraint=None, index_elements=None, index_where=None): if constraint is not None: - if not isinstance(constraint, util.string_types) and \ - isinstance(constraint, ( - schema.Index, schema.Constraint, - ext.ExcludeConstraint)): - constraint = getattr(constraint, 'name') or constraint + if not isinstance(constraint, util.string_types) and isinstance( + constraint, + (schema.Index, schema.Constraint, ext.ExcludeConstraint), + ): + constraint = getattr(constraint, "name") or constraint if constraint is not None: if index_elements is not None: raise ValueError( - "'constraint' and 'index_elements' are mutually exclusive") + "'constraint' and 'index_elements' are mutually exclusive" + ) if isinstance(constraint, util.string_types): self.constraint_target = constraint @@ -161,54 +165,61 @@ class OnConflictClause(ClauseElement): self.inferred_target_whereclause = None elif isinstance(constraint, schema.Index): index_elements = constraint.expressions - index_where = \ - constraint.dialect_options['postgresql'].get("where") + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) elif isinstance(constraint, ext.ExcludeConstraint): index_elements = constraint.columns index_where = constraint.where else: index_elements = constraint.columns - index_where = \ - constraint.dialect_options['postgresql'].get("where") + index_where = constraint.dialect_options["postgresql"].get( + "where" + ) if index_elements is not None: self.constraint_target = None self.inferred_target_elements = index_elements self.inferred_target_whereclause = index_where elif constraint is None: - self.constraint_target = self.inferred_target_elements = \ - self.inferred_target_whereclause = None + self.constraint_target = ( + self.inferred_target_elements + ) = self.inferred_target_whereclause = None class OnConflictDoNothing(OnConflictClause): - __visit_name__ = 'on_conflict_do_nothing' + __visit_name__ = "on_conflict_do_nothing" class OnConflictDoUpdate(OnConflictClause): - __visit_name__ = 'on_conflict_do_update' + __visit_name__ = "on_conflict_do_update" def __init__( - self, - constraint=None, - index_elements=None, - index_where=None, - set_=None, - where=None): + self, + constraint=None, + index_elements=None, + index_where=None, + set_=None, + where=None, + ): super(OnConflictDoUpdate, self).__init__( constraint=constraint, index_elements=index_elements, - index_where=index_where) + index_where=index_where, + ) - if self.inferred_target_elements is None and \ - self.constraint_target is None: + if ( + self.inferred_target_elements is None + and self.constraint_target is None + ): raise ValueError( "Either constraint or index_elements, " - "but not both, must be specified unless DO NOTHING") + "but not both, must be specified unless DO NOTHING" + ) - if (not isinstance(set_, dict) or not set_): + if not isinstance(set_, dict) or not set_: raise ValueError("set parameter must be a non-empty dictionary") self.update_values_to_set = [ - (key, value) - for key, value in set_.items() + (key, value) for key, value in set_.items() ] self.update_whereclause = where diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py index a588eafdd..da0c6250c 100644 --- a/lib/sqlalchemy/dialects/postgresql/ext.py +++ b/lib/sqlalchemy/dialects/postgresql/ext.py @@ -47,7 +47,7 @@ class aggregate_order_by(expression.ColumnElement): """ - __visit_name__ = 'aggregate_order_by' + __visit_name__ = "aggregate_order_by" def __init__(self, target, *order_by): self.target = elements._literal_as_binds(target) @@ -59,8 +59,8 @@ class aggregate_order_by(expression.ColumnElement): self.order_by = elements._literal_as_binds(order_by[0]) else: self.order_by = elements.ClauseList( - *order_by, - _literal_as_text=elements._literal_as_binds) + *order_by, _literal_as_text=elements._literal_as_binds + ) def self_group(self, against=None): return self @@ -87,7 +87,7 @@ class ExcludeConstraint(ColumnCollectionConstraint): static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE """ - __visit_name__ = 'exclude_constraint' + __visit_name__ = "exclude_constraint" where = None @@ -173,8 +173,7 @@ static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE expressions, operators = zip(*elements) for (expr, column, strname, add_element), operator in zip( - self._extract_col_expression_collection(expressions), - operators + self._extract_col_expression_collection(expressions), operators ): if add_element is not None: columns.append(add_element) @@ -187,32 +186,31 @@ static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE expr = expression._literal_as_text(expr) - render_exprs.append( - (expr, name, operator) - ) + render_exprs.append((expr, name, operator)) self._render_exprs = render_exprs ColumnCollectionConstraint.__init__( self, *columns, - name=kw.get('name'), - deferrable=kw.get('deferrable'), - initially=kw.get('initially') + name=kw.get("name"), + deferrable=kw.get("deferrable"), + initially=kw.get("initially") ) - self.using = kw.get('using', 'gist') - where = kw.get('where') + self.using = kw.get("using", "gist") + where = kw.get("where") if where is not None: self.where = expression._literal_as_text(where) def copy(self, **kw): - elements = [(col, self.operators[col]) - for col in self.columns.keys()] - c = self.__class__(*elements, - name=self.name, - deferrable=self.deferrable, - initially=self.initially, - where=self.where, - using=self.using) + elements = [(col, self.operators[col]) for col in self.columns.keys()] + c = self.__class__( + *elements, + name=self.name, + deferrable=self.deferrable, + initially=self.initially, + where=self.where, + using=self.using + ) c.dispatch._update(self.dispatch) return c @@ -226,5 +224,5 @@ def array_agg(*arg, **kw): .. versionadded:: 1.1 """ - kw['_default_array_type'] = ARRAY + kw["_default_array_type"] = ARRAY return functions.func.array_agg(*arg, **kw) diff --git a/lib/sqlalchemy/dialects/postgresql/hstore.py b/lib/sqlalchemy/dialects/postgresql/hstore.py index b6c9e7124..e4bac692a 100644 --- a/lib/sqlalchemy/dialects/postgresql/hstore.py +++ b/lib/sqlalchemy/dialects/postgresql/hstore.py @@ -14,38 +14,50 @@ from ...sql import functions as sqlfunc from ...sql import operators from ... import util -__all__ = ('HSTORE', 'hstore') +__all__ = ("HSTORE", "hstore") idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] GETITEM = operators.custom_op( - "->", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "->", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_KEY = operators.custom_op( - "?", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_ALL = operators.custom_op( - "?&", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?&", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_ANY = operators.custom_op( - "?|", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?|", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) CONTAINS = operators.custom_op( - "@>", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "@>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) CONTAINED_BY = operators.custom_op( - "<@", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "<@", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) @@ -122,7 +134,7 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): """ - __visit_name__ = 'HSTORE' + __visit_name__ = "HSTORE" hashable = False text_type = sqltypes.Text() @@ -139,7 +151,8 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): self.text_type = text_type class Comparator( - sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator): + sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator + ): """Define comparison operations for :class:`.HSTORE`.""" def has_key(self, other): @@ -169,7 +182,8 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): keys of the argument jsonb expression. """ return self.operate( - CONTAINED_BY, other, result_type=sqltypes.Boolean) + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) def _setup_getitem(self, index): return GETITEM, index, self.type.text_type @@ -223,12 +237,15 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): return _serialize_hstore(value).encode(encoding) else: return value + else: + def process(value): if isinstance(value, dict): return _serialize_hstore(value) else: return value + return process def result_processor(self, dialect, coltype): @@ -240,16 +257,19 @@ class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine): return _parse_hstore(value.decode(encoding)) else: return value + else: + def process(value): if value is not None: return _parse_hstore(value) else: return value + return process -ischema_names['hstore'] = HSTORE +ischema_names["hstore"] = HSTORE class hstore(sqlfunc.GenericFunction): @@ -279,43 +299,44 @@ class hstore(sqlfunc.GenericFunction): :class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype. """ + type = HSTORE - name = 'hstore' + name = "hstore" class _HStoreDefinedFunction(sqlfunc.GenericFunction): type = sqltypes.Boolean - name = 'defined' + name = "defined" class _HStoreDeleteFunction(sqlfunc.GenericFunction): type = HSTORE - name = 'delete' + name = "delete" class _HStoreSliceFunction(sqlfunc.GenericFunction): type = HSTORE - name = 'slice' + name = "slice" class _HStoreKeysFunction(sqlfunc.GenericFunction): type = ARRAY(sqltypes.Text) - name = 'akeys' + name = "akeys" class _HStoreValsFunction(sqlfunc.GenericFunction): type = ARRAY(sqltypes.Text) - name = 'avals' + name = "avals" class _HStoreArrayFunction(sqlfunc.GenericFunction): type = ARRAY(sqltypes.Text) - name = 'hstore_to_array' + name = "hstore_to_array" class _HStoreMatrixFunction(sqlfunc.GenericFunction): type = ARRAY(sqltypes.Text) - name = 'hstore_to_matrix' + name = "hstore_to_matrix" # @@ -326,7 +347,8 @@ class _HStoreMatrixFunction(sqlfunc.GenericFunction): # My best guess at the parsing rules of hstore literals, since no formal # grammar is given. This is mostly reverse engineered from PG's input parser # behavior. -HSTORE_PAIR_RE = re.compile(r""" +HSTORE_PAIR_RE = re.compile( + r""" ( "(?P<key> (\\ . | [^"])* )" # Quoted key ) @@ -335,11 +357,16 @@ HSTORE_PAIR_RE = re.compile(r""" (?P<value_null> NULL ) # NULL value | "(?P<value> (\\ . | [^"])* )" # Quoted value ) -""", re.VERBOSE) +""", + re.VERBOSE, +) -HSTORE_DELIMITER_RE = re.compile(r""" +HSTORE_DELIMITER_RE = re.compile( + r""" [ ]* , [ ]* -""", re.VERBOSE) +""", + re.VERBOSE, +) def _parse_error(hstore_str, pos): @@ -348,16 +375,19 @@ def _parse_error(hstore_str, pos): ctx = 20 hslen = len(hstore_str) - parsed_tail = hstore_str[max(pos - ctx - 1, 0):min(pos, hslen)] - residual = hstore_str[min(pos, hslen):min(pos + ctx + 1, hslen)] + parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)] + residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)] if len(parsed_tail) > ctx: - parsed_tail = '[...]' + parsed_tail[1:] + parsed_tail = "[...]" + parsed_tail[1:] if len(residual) > ctx: - residual = residual[:-1] + '[...]' + residual = residual[:-1] + "[...]" return "After %r, could not parse residual at position %d: %r" % ( - parsed_tail, pos, residual) + parsed_tail, + pos, + residual, + ) def _parse_hstore(hstore_str): @@ -377,13 +407,15 @@ def _parse_hstore(hstore_str): pair_match = HSTORE_PAIR_RE.match(hstore_str) while pair_match is not None: - key = pair_match.group('key').replace(r'\"', '"').replace( - "\\\\", "\\") - if pair_match.group('value_null'): + key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\") + if pair_match.group("value_null"): value = None else: - value = pair_match.group('value').replace( - r'\"', '"').replace("\\\\", "\\") + value = ( + pair_match.group("value") + .replace(r"\"", '"') + .replace("\\\\", "\\") + ) result[key] = value pos += pair_match.end() @@ -405,16 +437,17 @@ def _serialize_hstore(val): both be strings (except None for values). """ + def esc(s, position): - if position == 'value' and s is None: - return 'NULL' + if position == "value" and s is None: + return "NULL" elif isinstance(s, util.string_types): - return '"%s"' % s.replace("\\", "\\\\").replace('"', r'\"') + return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"") else: - raise ValueError("%r in %s position is not a string." % - (s, position)) - - return ', '.join('%s=>%s' % (esc(k, 'key'), esc(v, 'value')) - for k, v in val.items()) - + raise ValueError( + "%r in %s position is not a string." % (s, position) + ) + return ", ".join( + "%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items() + ) diff --git a/lib/sqlalchemy/dialects/postgresql/json.py b/lib/sqlalchemy/dialects/postgresql/json.py index e9256daf3..f9421de37 100644 --- a/lib/sqlalchemy/dialects/postgresql/json.py +++ b/lib/sqlalchemy/dialects/postgresql/json.py @@ -12,44 +12,58 @@ from ...sql import operators from ...sql import elements from ... import util -__all__ = ('JSON', 'JSONB') +__all__ = ("JSON", "JSONB") idx_precedence = operators._PRECEDENCE[operators.json_getitem_op] ASTEXT = operators.custom_op( - "->>", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "->>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) JSONPATH_ASTEXT = operators.custom_op( - "#>>", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "#>>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_KEY = operators.custom_op( - "?", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_ALL = operators.custom_op( - "?&", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?&", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) HAS_ANY = operators.custom_op( - "?|", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "?|", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) CONTAINS = operators.custom_op( - "@>", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "@>", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) CONTAINED_BY = operators.custom_op( - "<@", precedence=idx_precedence, natural_self_precedent=True, - eager_grouping=True + "<@", + precedence=idx_precedence, + natural_self_precedent=True, + eager_grouping=True, ) @@ -59,7 +73,7 @@ class JSONPathType(sqltypes.JSON.JSONPathType): def process(value): assert isinstance(value, util.collections_abc.Sequence) - tokens = [util.text_type(elem)for elem in value] + tokens = [util.text_type(elem) for elem in value] value = "{%s}" % (", ".join(tokens)) if super_proc: value = super_proc(value) @@ -72,7 +86,7 @@ class JSONPathType(sqltypes.JSON.JSONPathType): def process(value): assert isinstance(value, util.collections_abc.Sequence) - tokens = [util.text_type(elem)for elem in value] + tokens = [util.text_type(elem) for elem in value] value = "{%s}" % (", ".join(tokens)) if super_proc: value = super_proc(value) @@ -80,6 +94,7 @@ class JSONPathType(sqltypes.JSON.JSONPathType): return process + colspecs[sqltypes.JSON.JSONPathType] = JSONPathType @@ -203,16 +218,19 @@ class JSON(sqltypes.JSON): if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType): return self.expr.left.operate( JSONPATH_ASTEXT, - self.expr.right, result_type=self.type.astext_type) + self.expr.right, + result_type=self.type.astext_type, + ) else: return self.expr.left.operate( - ASTEXT, self.expr.right, result_type=self.type.astext_type) + ASTEXT, self.expr.right, result_type=self.type.astext_type + ) comparator_factory = Comparator colspecs[sqltypes.JSON] = JSON -ischema_names['json'] = JSON +ischema_names["json"] = JSON class JSONB(JSON): @@ -259,7 +277,7 @@ class JSONB(JSON): """ - __visit_name__ = 'JSONB' + __visit_name__ = "JSONB" class Comparator(JSON.Comparator): """Define comparison operations for :class:`.JSON`.""" @@ -291,8 +309,10 @@ class JSONB(JSON): keys of the argument jsonb expression. """ return self.operate( - CONTAINED_BY, other, result_type=sqltypes.Boolean) + CONTAINED_BY, other, result_type=sqltypes.Boolean + ) comparator_factory = Comparator -ischema_names['jsonb'] = JSONB + +ischema_names["jsonb"] = JSONB diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 80929b808..fef09e0eb 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -69,8 +69,15 @@ import decimal from ... import processors from ... import types as sqltypes from .base import ( - PGDialect, PGCompiler, PGIdentifierPreparer, PGExecutionContext, - _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID) + PGDialect, + PGCompiler, + PGIdentifierPreparer, + PGExecutionContext, + _DECIMAL_TYPES, + _FLOAT_TYPES, + _INT_TYPES, + UUID, +) import re from sqlalchemy.dialects.postgresql.json import JSON from ...sql.elements import quoted_name @@ -86,13 +93,15 @@ class _PGNumeric(sqltypes.Numeric): if self.asdecimal: if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( - decimal.Decimal, self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # pg8000 returns Decimal natively for 1700 return None else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) else: if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 @@ -101,7 +110,8 @@ class _PGNumeric(sqltypes.Numeric): return processors.to_float else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) class _PGNumericNoBind(_PGNumeric): @@ -110,7 +120,6 @@ class _PGNumericNoBind(_PGNumeric): class _PGJSON(JSON): - def result_processor(self, dialect, coltype): if dialect._dbapi_version > (1, 10, 1): return None # Has native JSON @@ -121,18 +130,22 @@ class _PGJSON(JSON): class _PGUUID(UUID): def bind_processor(self, dialect): if not self.as_uuid: + def process(value): if value is not None: value = _python_UUID(value) return value + return process def result_processor(self, dialect, coltype): if not self.as_uuid: + def process(value): if value is not None: value = str(value) return value + return process @@ -142,36 +155,41 @@ class PGExecutionContext_pg8000(PGExecutionContext): class PGCompiler_pg8000(PGCompiler): def visit_mod_binary(self, binary, operator, **kw): - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) def post_process_text(self, text): - if '%%' in text: - util.warn("The SQLAlchemy postgresql dialect " - "now automatically escapes '%' in text() " - "expressions to '%%'.") - return text.replace('%', '%%') + if "%%" in text: + util.warn( + "The SQLAlchemy postgresql dialect " + "now automatically escapes '%' in text() " + "expressions to '%%'." + ) + return text.replace("%", "%%") class PGIdentifierPreparer_pg8000(PGIdentifierPreparer): def _escape_identifier(self, value): value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace('%', '%%') + return value.replace("%", "%%") class PGDialect_pg8000(PGDialect): - driver = 'pg8000' + driver = "pg8000" supports_unicode_statements = True supports_unicode_binds = True - default_paramstyle = 'format' + default_paramstyle = "format" supports_sane_multi_rowcount = True execution_ctx_cls = PGExecutionContext_pg8000 statement_compiler = PGCompiler_pg8000 preparer = PGIdentifierPreparer_pg8000 - description_encoding = 'use_encoding' + description_encoding = "use_encoding" colspecs = util.update_copy( PGDialect.colspecs, @@ -180,8 +198,8 @@ class PGDialect_pg8000(PGDialect): sqltypes.Float: _PGNumeric, JSON: _PGJSON, sqltypes.JSON: _PGJSON, - UUID: _PGUUID - } + UUID: _PGUUID, + }, ) def __init__(self, client_encoding=None, **kwargs): @@ -194,22 +212,26 @@ class PGDialect_pg8000(PGDialect): @util.memoized_property def _dbapi_version(self): - if self.dbapi and hasattr(self.dbapi, '__version__'): + if self.dbapi and hasattr(self.dbapi, "__version__"): return tuple( [ - int(x) for x in re.findall( - r'(\d+)(?:[-\.]?|$)', self.dbapi.__version__)]) + int(x) + for x in re.findall( + r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__ + ) + ] + ) else: return (99, 99, 99) @classmethod def dbapi(cls): - return __import__('pg8000') + return __import__("pg8000") def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) opts.update(url.query) return ([], opts) @@ -217,32 +239,33 @@ class PGDialect_pg8000(PGDialect): return "connection is closed" in str(e) def set_isolation_level(self, connection, level): - level = level.replace('_', ' ') + level = level.replace("_", " ") # adjust for ConnectionFairy possibly being present - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection - if level == 'AUTOCOMMIT': + if level == "AUTOCOMMIT": connection.autocommit = True elif level in self._isolation_lookup: connection.autocommit = False cursor = connection.cursor() cursor.execute( "SET SESSION CHARACTERISTICS AS TRANSACTION " - "ISOLATION LEVEL %s" % level) + "ISOLATION LEVEL %s" % level + ) cursor.execute("COMMIT") cursor.close() else: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s or AUTOCOMMIT" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s or AUTOCOMMIT" + % (level, self.name, ", ".join(self._isolation_lookup)) ) def set_client_encoding(self, connection, client_encoding): # adjust for ConnectionFairy possibly being present - if hasattr(connection, 'connection'): + if hasattr(connection, "connection"): connection = connection.connection cursor = connection.cursor() @@ -251,18 +274,20 @@ class PGDialect_pg8000(PGDialect): cursor.close() def do_begin_twophase(self, connection, xid): - connection.connection.tpc_begin((0, xid, '')) + connection.connection.tpc_begin((0, xid, "")) def do_prepare_twophase(self, connection, xid): connection.connection.tpc_prepare() def do_rollback_twophase( - self, connection, xid, is_prepared=True, recover=False): - connection.connection.tpc_rollback((0, xid, '')) + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_rollback((0, xid, "")) def do_commit_twophase( - self, connection, xid, is_prepared=True, recover=False): - connection.connection.tpc_commit((0, xid, '')) + self, connection, xid, is_prepared=True, recover=False + ): + connection.connection.tpc_commit((0, xid, "")) def do_recover_twophase(self, connection): return [row[1] for row in connection.connection.tpc_recover()] @@ -272,24 +297,32 @@ class PGDialect_pg8000(PGDialect): def on_connect(conn): conn.py_types[quoted_name] = conn.py_types[util.text_type] + fns.append(on_connect) if self.client_encoding is not None: + def on_connect(conn): self.set_client_encoding(conn, self.client_encoding) + fns.append(on_connect) if self.isolation_level is not None: + def on_connect(conn): self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) if len(fns) > 0: + def on_connect(conn): for fn in fns: fn(conn) + return on_connect else: return None + dialect = PGDialect_pg8000 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2.py b/lib/sqlalchemy/dialects/postgresql/psycopg2.py index baa0e00d5..2c27c6919 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2.py @@ -353,10 +353,17 @@ from ... import processors from ...engine import result as _result from ...sql import expression from ... import types as sqltypes -from .base import PGDialect, PGCompiler, \ - PGIdentifierPreparer, PGExecutionContext, \ - ENUM, _DECIMAL_TYPES, _FLOAT_TYPES,\ - _INT_TYPES, UUID +from .base import ( + PGDialect, + PGCompiler, + PGIdentifierPreparer, + PGExecutionContext, + ENUM, + _DECIMAL_TYPES, + _FLOAT_TYPES, + _INT_TYPES, + UUID, +) from .hstore import HSTORE from .json import JSON, JSONB @@ -366,7 +373,7 @@ except ImportError: _python_UUID = None -logger = logging.getLogger('sqlalchemy.dialects.postgresql') +logger = logging.getLogger("sqlalchemy.dialects.postgresql") class _PGNumeric(sqltypes.Numeric): @@ -377,14 +384,15 @@ class _PGNumeric(sqltypes.Numeric): if self.asdecimal: if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( - decimal.Decimal, - self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # pg8000 returns Decimal natively for 1700 return None else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) else: if coltype in _FLOAT_TYPES: # pg8000 returns float natively for 701 @@ -393,7 +401,8 @@ class _PGNumeric(sqltypes.Numeric): return processors.to_float else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) class _PGEnum(ENUM): @@ -421,7 +430,6 @@ class _PGHStore(HSTORE): class _PGJSON(JSON): - def result_processor(self, dialect, coltype): if dialect._has_native_json: return None @@ -430,7 +438,6 @@ class _PGJSON(JSON): class _PGJSONB(JSONB): - def result_processor(self, dialect, coltype): if dialect._has_native_jsonb: return None @@ -447,14 +454,17 @@ class _PGUUID(UUID): if value is not None: value = _python_UUID(value) return value + return process def result_processor(self, dialect, coltype): if not self.as_uuid and dialect.use_native_uuid: + def process(value): if value is not None: value = str(value) return value + return process @@ -465,8 +475,7 @@ class PGExecutionContext_psycopg2(PGExecutionContext): def create_server_side_cursor(self): # use server-side cursors: # http://lists.initd.org/pipermail/psycopg/2007-January/005251.html - ident = "c_%s_%s" % (hex(id(self))[2:], - hex(_server_side_id())[2:]) + ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:]) return self._dbapi_connection.cursor(ident) def get_result_proxy(self): @@ -497,13 +506,13 @@ class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer): class PGDialect_psycopg2(PGDialect): - driver = 'psycopg2' + driver = "psycopg2" if util.py2k: supports_unicode_statements = False supports_server_side_cursors = True - default_paramstyle = 'pyformat' + default_paramstyle = "pyformat" # set to true based on psycopg2 version supports_sane_multi_rowcount = False execution_ctx_cls = PGExecutionContext_psycopg2 @@ -516,16 +525,16 @@ class PGDialect_psycopg2(PGDialect): native_jsonb=(2, 5, 4), sane_multi_rowcount=(2, 0, 9), array_oid=(2, 4, 3), - hstore_adapter=(2, 4) + hstore_adapter=(2, 4), ) _has_native_hstore = False _has_native_json = False _has_native_jsonb = False - engine_config_types = PGDialect.engine_config_types.union([ - ('use_native_unicode', util.asbool), - ]) + engine_config_types = PGDialect.engine_config_types.union( + [("use_native_unicode", util.asbool)] + ) colspecs = util.update_copy( PGDialect.colspecs, @@ -537,15 +546,20 @@ class PGDialect_psycopg2(PGDialect): JSON: _PGJSON, sqltypes.JSON: _PGJSON, JSONB: _PGJSONB, - UUID: _PGUUID - } + UUID: _PGUUID, + }, ) - def __init__(self, server_side_cursors=False, use_native_unicode=True, - client_encoding=None, - use_native_hstore=True, use_native_uuid=True, - use_batch_mode=False, - **kwargs): + def __init__( + self, + server_side_cursors=False, + use_native_unicode=True, + client_encoding=None, + use_native_hstore=True, + use_native_uuid=True, + use_batch_mode=False, + **kwargs + ): PGDialect.__init__(self, **kwargs) self.server_side_cursors = server_side_cursors self.use_native_unicode = use_native_unicode @@ -554,65 +568,70 @@ class PGDialect_psycopg2(PGDialect): self.supports_unicode_binds = use_native_unicode self.client_encoding = client_encoding self.psycopg2_batch_mode = use_batch_mode - if self.dbapi and hasattr(self.dbapi, '__version__'): - m = re.match(r'(\d+)\.(\d+)(?:\.(\d+))?', - self.dbapi.__version__) + if self.dbapi and hasattr(self.dbapi, "__version__"): + m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__) if m: self.psycopg2_version = tuple( - int(x) - for x in m.group(1, 2, 3) - if x is not None) + int(x) for x in m.group(1, 2, 3) if x is not None + ) def initialize(self, connection): super(PGDialect_psycopg2, self).initialize(connection) - self._has_native_hstore = self.use_native_hstore and \ - self._hstore_oids(connection.connection) \ - is not None - self._has_native_json = \ - self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_json'] - self._has_native_jsonb = \ - self.psycopg2_version >= self.FEATURE_VERSION_MAP['native_jsonb'] + self._has_native_hstore = ( + self.use_native_hstore + and self._hstore_oids(connection.connection) is not None + ) + self._has_native_json = ( + self.psycopg2_version >= self.FEATURE_VERSION_MAP["native_json"] + ) + self._has_native_jsonb = ( + self.psycopg2_version >= self.FEATURE_VERSION_MAP["native_jsonb"] + ) # http://initd.org/psycopg/docs/news.html#what-s-new-in-psycopg-2-0-9 - self.supports_sane_multi_rowcount = \ - self.psycopg2_version >= \ - self.FEATURE_VERSION_MAP['sane_multi_rowcount'] and \ - not self.psycopg2_batch_mode + self.supports_sane_multi_rowcount = ( + self.psycopg2_version + >= self.FEATURE_VERSION_MAP["sane_multi_rowcount"] + and not self.psycopg2_batch_mode + ) @classmethod def dbapi(cls): import psycopg2 + return psycopg2 @classmethod def _psycopg2_extensions(cls): from psycopg2 import extensions + return extensions @classmethod def _psycopg2_extras(cls): from psycopg2 import extras + return extras @util.memoized_property def _isolation_lookup(self): extensions = self._psycopg2_extensions() return { - 'AUTOCOMMIT': extensions.ISOLATION_LEVEL_AUTOCOMMIT, - 'READ COMMITTED': extensions.ISOLATION_LEVEL_READ_COMMITTED, - 'READ UNCOMMITTED': extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, - 'REPEATABLE READ': extensions.ISOLATION_LEVEL_REPEATABLE_READ, - 'SERIALIZABLE': extensions.ISOLATION_LEVEL_SERIALIZABLE + "AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT, + "READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED, + "READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED, + "REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ, + "SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE, } def set_isolation_level(self, connection, level): try: - level = self._isolation_lookup[level.replace('_', ' ')] + level = self._isolation_lookup[level.replace("_", " ")] except KeyError: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) connection.set_isolation_level(level) @@ -623,54 +642,72 @@ class PGDialect_psycopg2(PGDialect): fns = [] if self.client_encoding is not None: + def on_connect(conn): conn.set_client_encoding(self.client_encoding) + fns.append(on_connect) if self.isolation_level is not None: + def on_connect(conn): self.set_isolation_level(conn, self.isolation_level) + fns.append(on_connect) if self.dbapi and self.use_native_uuid: + def on_connect(conn): extras.register_uuid(None, conn) + fns.append(on_connect) if self.dbapi and self.use_native_unicode: + def on_connect(conn): extensions.register_type(extensions.UNICODE, conn) extensions.register_type(extensions.UNICODEARRAY, conn) + fns.append(on_connect) if self.dbapi and self.use_native_hstore: + def on_connect(conn): hstore_oids = self._hstore_oids(conn) if hstore_oids is not None: oid, array_oid = hstore_oids - kw = {'oid': oid} + kw = {"oid": oid} if util.py2k: - kw['unicode'] = True - if self.psycopg2_version >= \ - self.FEATURE_VERSION_MAP['array_oid']: - kw['array_oid'] = array_oid + kw["unicode"] = True + if ( + self.psycopg2_version + >= self.FEATURE_VERSION_MAP["array_oid"] + ): + kw["array_oid"] = array_oid extras.register_hstore(conn, **kw) + fns.append(on_connect) if self.dbapi and self._json_deserializer: + def on_connect(conn): if self._has_native_json: extras.register_default_json( - conn, loads=self._json_deserializer) + conn, loads=self._json_deserializer + ) if self._has_native_jsonb: extras.register_default_jsonb( - conn, loads=self._json_deserializer) + conn, loads=self._json_deserializer + ) + fns.append(on_connect) if fns: + def on_connect(conn): for fn in fns: fn(conn) + return on_connect else: return None @@ -684,7 +721,7 @@ class PGDialect_psycopg2(PGDialect): @util.memoized_instancemethod def _hstore_oids(self, conn): - if self.psycopg2_version >= self.FEATURE_VERSION_MAP['hstore_adapter']: + if self.psycopg2_version >= self.FEATURE_VERSION_MAP["hstore_adapter"]: extras = self._psycopg2_extras() oids = extras.HstoreAdapter.get_oids(conn) if oids is not None and oids[0]: @@ -692,9 +729,9 @@ class PGDialect_psycopg2(PGDialect): return None def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) opts.update(url.query) return ([], opts) @@ -704,7 +741,7 @@ class PGDialect_psycopg2(PGDialect): # present on old psycopg2 versions. Also, # this flag doesn't actually help in a lot of disconnect # situations, so don't rely on it. - if getattr(connection, 'closed', False): + if getattr(connection, "closed", False): return True # checks based on strings. in the case that .closed @@ -713,28 +750,29 @@ class PGDialect_psycopg2(PGDialect): for msg in [ # these error messages from libpq: interfaces/libpq/fe-misc.c # and interfaces/libpq/fe-secure.c. - 'terminating connection', - 'closed the connection', - 'connection not open', - 'could not receive data from server', - 'could not send data to server', + "terminating connection", + "closed the connection", + "connection not open", + "could not receive data from server", + "could not send data to server", # psycopg2 client errors, psycopg2/conenction.h, # psycopg2/cursor.h - 'connection already closed', - 'cursor already closed', + "connection already closed", + "cursor already closed", # not sure where this path is originally from, it may # be obsolete. It really says "losed", not "closed". - 'losed the connection unexpectedly', + "losed the connection unexpectedly", # these can occur in newer SSL - 'connection has been closed unexpectedly', - 'SSL SYSCALL error: Bad file descriptor', - 'SSL SYSCALL error: EOF detected', - 'SSL error: decryption failed or bad record mac', - 'SSL SYSCALL error: Operation timed out', + "connection has been closed unexpectedly", + "SSL SYSCALL error: Bad file descriptor", + "SSL SYSCALL error: EOF detected", + "SSL error: decryption failed or bad record mac", + "SSL SYSCALL error: Operation timed out", ]: idx = str_e.find(msg) if idx >= 0 and '"' not in str_e[:idx]: return True return False + dialect = PGDialect_psycopg2 diff --git a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py index a1141a90e..7343bc973 100644 --- a/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py +++ b/lib/sqlalchemy/dialects/postgresql/psycopg2cffi.py @@ -28,7 +28,7 @@ from .psycopg2 import PGDialect_psycopg2 class PGDialect_psycopg2cffi(PGDialect_psycopg2): - driver = 'psycopg2cffi' + driver = "psycopg2cffi" supports_unicode_statements = True # psycopg2cffi's first release is 2.5.0, but reports @@ -40,21 +40,21 @@ class PGDialect_psycopg2cffi(PGDialect_psycopg2): native_jsonb=(2, 7, 1), sane_multi_rowcount=(2, 4, 4), array_oid=(2, 4, 4), - hstore_adapter=(2, 4, 4) + hstore_adapter=(2, 4, 4), ) @classmethod def dbapi(cls): - return __import__('psycopg2cffi') + return __import__("psycopg2cffi") @classmethod def _psycopg2_extensions(cls): - root = __import__('psycopg2cffi', fromlist=['extensions']) + root = __import__("psycopg2cffi", fromlist=["extensions"]) return root.extensions @classmethod def _psycopg2_extras(cls): - root = __import__('psycopg2cffi', fromlist=['extras']) + root = __import__("psycopg2cffi", fromlist=["extras"]) return root.extras diff --git a/lib/sqlalchemy/dialects/postgresql/pygresql.py b/lib/sqlalchemy/dialects/postgresql/pygresql.py index 304afca44..c7edb8fc3 100644 --- a/lib/sqlalchemy/dialects/postgresql/pygresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pygresql.py @@ -20,14 +20,20 @@ import re from ... import exc, processors, util from ...types import Numeric, JSON as Json from ...sql.elements import Null -from .base import PGDialect, PGCompiler, PGIdentifierPreparer, \ - _DECIMAL_TYPES, _FLOAT_TYPES, _INT_TYPES, UUID +from .base import ( + PGDialect, + PGCompiler, + PGIdentifierPreparer, + _DECIMAL_TYPES, + _FLOAT_TYPES, + _INT_TYPES, + UUID, +) from .hstore import HSTORE from .json import JSON, JSONB class _PGNumeric(Numeric): - def bind_processor(self, dialect): return None @@ -37,14 +43,15 @@ class _PGNumeric(Numeric): if self.asdecimal: if coltype in _FLOAT_TYPES: return processors.to_decimal_processor_factory( - decimal.Decimal, - self._effective_decimal_return_scale) + decimal.Decimal, self._effective_decimal_return_scale + ) elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES: # PyGreSQL returns Decimal natively for 1700 (numeric) return None else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) else: if coltype in _FLOAT_TYPES: # PyGreSQL returns float natively for 701 (float8) @@ -53,19 +60,21 @@ class _PGNumeric(Numeric): return processors.to_float else: raise exc.InvalidRequestError( - "Unknown PG numeric type: %d" % coltype) + "Unknown PG numeric type: %d" % coltype + ) class _PGHStore(HSTORE): - def bind_processor(self, dialect): if not dialect.has_native_hstore: return super(_PGHStore, self).bind_processor(dialect) hstore = dialect.dbapi.Hstore + def process(value): if isinstance(value, dict): return hstore(value) return value + return process def result_processor(self, dialect, coltype): @@ -74,7 +83,6 @@ class _PGHStore(HSTORE): class _PGJSON(JSON): - def bind_processor(self, dialect): if not dialect.has_native_json: return super(_PGJSON, self).bind_processor(dialect) @@ -84,7 +92,8 @@ class _PGJSON(JSON): if value is self.NULL: value = None elif isinstance(value, Null) or ( - value is None and self.none_as_null): + value is None and self.none_as_null + ): return None if value is None or isinstance(value, (dict, list)): return json(value) @@ -98,7 +107,6 @@ class _PGJSON(JSON): class _PGJSONB(JSONB): - def bind_processor(self, dialect): if not dialect.has_native_json: return super(_PGJSONB, self).bind_processor(dialect) @@ -108,7 +116,8 @@ class _PGJSONB(JSONB): if value is self.NULL: value = None elif isinstance(value, Null) or ( - value is None and self.none_as_null): + value is None and self.none_as_null + ): return None if value is None or isinstance(value, (dict, list)): return json(value) @@ -122,7 +131,6 @@ class _PGJSONB(JSONB): class _PGUUID(UUID): - def bind_processor(self, dialect): if not dialect.has_native_uuid: return super(_PGUUID, self).bind_processor(dialect) @@ -145,32 +153,35 @@ class _PGUUID(UUID): if not dialect.has_native_uuid: return super(_PGUUID, self).result_processor(dialect, coltype) if not self.as_uuid: + def process(value): if value is not None: return str(value) + return process class _PGCompiler(PGCompiler): - def visit_mod_binary(self, binary, operator, **kw): - return self.process(binary.left, **kw) + " %% " + \ - self.process(binary.right, **kw) + return ( + self.process(binary.left, **kw) + + " %% " + + self.process(binary.right, **kw) + ) def post_process_text(self, text): - return text.replace('%', '%%') + return text.replace("%", "%%") class _PGIdentifierPreparer(PGIdentifierPreparer): - def _escape_identifier(self, value): value = value.replace(self.escape_quote, self.escape_to_quote) - return value.replace('%', '%%') + return value.replace("%", "%%") class PGDialect_pygresql(PGDialect): - driver = 'pygresql' + driver = "pygresql" statement_compiler = _PGCompiler preparer = _PGIdentifierPreparer @@ -178,6 +189,7 @@ class PGDialect_pygresql(PGDialect): @classmethod def dbapi(cls): import pgdb + return pgdb colspecs = util.update_copy( @@ -189,14 +201,14 @@ class PGDialect_pygresql(PGDialect): JSON: _PGJSON, JSONB: _PGJSONB, UUID: _PGUUID, - } + }, ) def __init__(self, **kwargs): super(PGDialect_pygresql, self).__init__(**kwargs) try: version = self.dbapi.version - m = re.match(r'(\d+)\.(\d+)', version) + m = re.match(r"(\d+)\.(\d+)", version) version = (int(m.group(1)), int(m.group(2))) except (AttributeError, ValueError, TypeError): version = (0, 0) @@ -204,8 +216,10 @@ class PGDialect_pygresql(PGDialect): if version < (5, 0): has_native_hstore = has_native_json = has_native_uuid = False if version != (0, 0): - util.warn("PyGreSQL is only fully supported by SQLAlchemy" - " since version 5.0.") + util.warn( + "PyGreSQL is only fully supported by SQLAlchemy" + " since version 5.0." + ) else: self.supports_unicode_statements = True self.supports_unicode_binds = True @@ -215,10 +229,12 @@ class PGDialect_pygresql(PGDialect): self.has_native_uuid = has_native_uuid def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['host'] = '%s:%s' % ( - opts.get('host', '').rsplit(':', 1)[0], opts.pop('port')) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["host"] = "%s:%s" % ( + opts.get("host", "").rsplit(":", 1)[0], + opts.pop("port"), + ) opts.update(url.query) return [], opts diff --git a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py index b633323b4..93bf653a4 100644 --- a/lib/sqlalchemy/dialects/postgresql/pypostgresql.py +++ b/lib/sqlalchemy/dialects/postgresql/pypostgresql.py @@ -37,12 +37,12 @@ class PGExecutionContext_pypostgresql(PGExecutionContext): class PGDialect_pypostgresql(PGDialect): - driver = 'pypostgresql' + driver = "pypostgresql" supports_unicode_statements = True supports_unicode_binds = True description_encoding = None - default_paramstyle = 'pyformat' + default_paramstyle = "pyformat" # requires trunk version to support sane rowcounts # TODO: use dbapi version information to set this flag appropriately @@ -54,22 +54,27 @@ class PGDialect_pypostgresql(PGDialect): PGDialect.colspecs, { sqltypes.Numeric: PGNumeric, - # prevents PGNumeric from being used sqltypes.Float: sqltypes.Float, - } + }, ) @classmethod def dbapi(cls): from postgresql.driver import dbapi20 + return dbapi20 _DBAPI_ERROR_NAMES = [ "Error", - "InterfaceError", "DatabaseError", "DataError", - "OperationalError", "IntegrityError", "InternalError", - "ProgrammingError", "NotSupportedError" + "InterfaceError", + "DatabaseError", + "DataError", + "OperationalError", + "IntegrityError", + "InternalError", + "ProgrammingError", + "NotSupportedError", ] @util.memoized_property @@ -83,15 +88,16 @@ class PGDialect_pypostgresql(PGDialect): ) def create_connect_args(self, url): - opts = url.translate_connect_args(username='user') - if 'port' in opts: - opts['port'] = int(opts['port']) + opts = url.translate_connect_args(username="user") + if "port" in opts: + opts["port"] = int(opts["port"]) else: - opts['port'] = 5432 + opts["port"] = 5432 opts.update(url.query) return ([], opts) def is_disconnect(self, e, connection, cursor): return "connection is closed" in str(e) + dialect = PGDialect_pypostgresql diff --git a/lib/sqlalchemy/dialects/postgresql/ranges.py b/lib/sqlalchemy/dialects/postgresql/ranges.py index eb2d86bbd..62d1275a6 100644 --- a/lib/sqlalchemy/dialects/postgresql/ranges.py +++ b/lib/sqlalchemy/dialects/postgresql/ranges.py @@ -7,7 +7,7 @@ from .base import ischema_names from ... import types as sqltypes -__all__ = ('INT4RANGE', 'INT8RANGE', 'NUMRANGE') +__all__ = ("INT4RANGE", "INT8RANGE", "NUMRANGE") class RangeOperators(object): @@ -34,35 +34,36 @@ class RangeOperators(object): def __ne__(self, other): "Boolean expression. Returns true if two ranges are not equal" if other is None: - return super( - RangeOperators.comparator_factory, self).__ne__(other) + return super(RangeOperators.comparator_factory, self).__ne__( + other + ) else: - return self.expr.op('<>')(other) + return self.expr.op("<>")(other) def contains(self, other, **kw): """Boolean expression. Returns true if the right hand operand, which can be an element or a range, is contained within the column. """ - return self.expr.op('@>')(other) + return self.expr.op("@>")(other) def contained_by(self, other): """Boolean expression. Returns true if the column is contained within the right hand operand. """ - return self.expr.op('<@')(other) + return self.expr.op("<@")(other) def overlaps(self, other): """Boolean expression. Returns true if the column overlaps (has points in common with) the right hand operand. """ - return self.expr.op('&&')(other) + return self.expr.op("&&")(other) def strictly_left_of(self, other): """Boolean expression. Returns true if the column is strictly left of the right hand operand. """ - return self.expr.op('<<')(other) + return self.expr.op("<<")(other) __lshift__ = strictly_left_of @@ -70,7 +71,7 @@ class RangeOperators(object): """Boolean expression. Returns true if the column is strictly right of the right hand operand. """ - return self.expr.op('>>')(other) + return self.expr.op(">>")(other) __rshift__ = strictly_right_of @@ -78,26 +79,26 @@ class RangeOperators(object): """Boolean expression. Returns true if the range in the column does not extend right of the range in the operand. """ - return self.expr.op('&<')(other) + return self.expr.op("&<")(other) def not_extend_left_of(self, other): """Boolean expression. Returns true if the range in the column does not extend left of the range in the operand. """ - return self.expr.op('&>')(other) + return self.expr.op("&>")(other) def adjacent_to(self, other): """Boolean expression. Returns true if the range in the column is adjacent to the range in the operand. """ - return self.expr.op('-|-')(other) + return self.expr.op("-|-")(other) def __add__(self, other): """Range expression. Returns the union of the two ranges. Will raise an exception if the resulting range is not contigous. """ - return self.expr.op('+')(other) + return self.expr.op("+")(other) class INT4RANGE(RangeOperators, sqltypes.TypeEngine): @@ -107,9 +108,10 @@ class INT4RANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'INT4RANGE' + __visit_name__ = "INT4RANGE" -ischema_names['int4range'] = INT4RANGE + +ischema_names["int4range"] = INT4RANGE class INT8RANGE(RangeOperators, sqltypes.TypeEngine): @@ -119,9 +121,10 @@ class INT8RANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'INT8RANGE' + __visit_name__ = "INT8RANGE" + -ischema_names['int8range'] = INT8RANGE +ischema_names["int8range"] = INT8RANGE class NUMRANGE(RangeOperators, sqltypes.TypeEngine): @@ -131,9 +134,10 @@ class NUMRANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'NUMRANGE' + __visit_name__ = "NUMRANGE" + -ischema_names['numrange'] = NUMRANGE +ischema_names["numrange"] = NUMRANGE class DATERANGE(RangeOperators, sqltypes.TypeEngine): @@ -143,9 +147,10 @@ class DATERANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'DATERANGE' + __visit_name__ = "DATERANGE" -ischema_names['daterange'] = DATERANGE + +ischema_names["daterange"] = DATERANGE class TSRANGE(RangeOperators, sqltypes.TypeEngine): @@ -155,9 +160,10 @@ class TSRANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'TSRANGE' + __visit_name__ = "TSRANGE" + -ischema_names['tsrange'] = TSRANGE +ischema_names["tsrange"] = TSRANGE class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): @@ -167,6 +173,7 @@ class TSTZRANGE(RangeOperators, sqltypes.TypeEngine): """ - __visit_name__ = 'TSTZRANGE' + __visit_name__ = "TSTZRANGE" + -ischema_names['tstzrange'] = TSTZRANGE +ischema_names["tstzrange"] = TSTZRANGE diff --git a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py index ef6e8f1f9..4d984443a 100644 --- a/lib/sqlalchemy/dialects/postgresql/zxjdbc.py +++ b/lib/sqlalchemy/dialects/postgresql/zxjdbc.py @@ -19,7 +19,6 @@ from .base import PGDialect, PGExecutionContext class PGExecutionContext_zxjdbc(PGExecutionContext): - def create_cursor(self): cursor = self._dbapi_connection.cursor() cursor.datahandler = self.dialect.DataHandler(cursor.datahandler) @@ -27,8 +26,8 @@ class PGExecutionContext_zxjdbc(PGExecutionContext): class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect): - jdbc_db_name = 'postgresql' - jdbc_driver_name = 'org.postgresql.Driver' + jdbc_db_name = "postgresql" + jdbc_driver_name = "org.postgresql.Driver" execution_ctx_cls = PGExecutionContext_zxjdbc @@ -37,10 +36,12 @@ class PGDialect_zxjdbc(ZxJDBCConnector, PGDialect): def __init__(self, *args, **kwargs): super(PGDialect_zxjdbc, self).__init__(*args, **kwargs) from com.ziclix.python.sql.handler import PostgresqlDataHandler + self.DataHandler = PostgresqlDataHandler def _get_server_version_info(self, connection): - parts = connection.connection.dbversion.split('.') + parts = connection.connection.dbversion.split(".") return tuple(int(x) for x in parts) + dialect = PGDialect_zxjdbc diff --git a/lib/sqlalchemy/dialects/sqlite/__init__.py b/lib/sqlalchemy/dialects/sqlite/__init__.py index a73581521..41f017597 100644 --- a/lib/sqlalchemy/dialects/sqlite/__init__.py +++ b/lib/sqlalchemy/dialects/sqlite/__init__.py @@ -8,14 +8,44 @@ from . import base, pysqlite, pysqlcipher # noqa from sqlalchemy.dialects.sqlite.base import ( - BLOB, BOOLEAN, CHAR, DATE, DATETIME, DECIMAL, FLOAT, INTEGER, JSON, REAL, - NUMERIC, SMALLINT, TEXT, TIME, TIMESTAMP, VARCHAR + BLOB, + BOOLEAN, + CHAR, + DATE, + DATETIME, + DECIMAL, + FLOAT, + INTEGER, + JSON, + REAL, + NUMERIC, + SMALLINT, + TEXT, + TIME, + TIMESTAMP, + VARCHAR, ) # default dialect base.dialect = dialect = pysqlite.dialect -__all__ = ('BLOB', 'BOOLEAN', 'CHAR', 'DATE', 'DATETIME', 'DECIMAL', - 'FLOAT', 'INTEGER', 'JSON', 'NUMERIC', 'SMALLINT', 'TEXT', 'TIME', - 'TIMESTAMP', 'VARCHAR', 'REAL', 'dialect') +__all__ = ( + "BLOB", + "BOOLEAN", + "CHAR", + "DATE", + "DATETIME", + "DECIMAL", + "FLOAT", + "INTEGER", + "JSON", + "NUMERIC", + "SMALLINT", + "TEXT", + "TIME", + "TIMESTAMP", + "VARCHAR", + "REAL", + "dialect", +) diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index c487af898..cb9389af1 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -579,9 +579,20 @@ from ... import util from ...engine import default, reflection from ...sql import compiler -from ...types import (BLOB, BOOLEAN, CHAR, DECIMAL, FLOAT, - INTEGER, REAL, NUMERIC, SMALLINT, TEXT, - TIMESTAMP, VARCHAR) +from ...types import ( + BLOB, + BOOLEAN, + CHAR, + DECIMAL, + FLOAT, + INTEGER, + REAL, + NUMERIC, + SMALLINT, + TEXT, + TIMESTAMP, + VARCHAR, +) from .json import JSON, JSONIndexType, JSONPathType @@ -610,10 +621,15 @@ class _DateTimeMixin(object): """ spec = self._storage_format % { - "year": 0, "month": 0, "day": 0, "hour": 0, - "minute": 0, "second": 0, "microsecond": 0 + "year": 0, + "month": 0, + "day": 0, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, } - return bool(re.search(r'[^0-9]', spec)) + return bool(re.search(r"[^0-9]", spec)) def adapt(self, cls, **kw): if issubclass(cls, _DateTimeMixin): @@ -628,6 +644,7 @@ class _DateTimeMixin(object): def process(value): return "'%s'" % bp(value) + return process @@ -671,13 +688,17 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): ) def __init__(self, *args, **kwargs): - truncate_microseconds = kwargs.pop('truncate_microseconds', False) + truncate_microseconds = kwargs.pop("truncate_microseconds", False) super(DATETIME, self).__init__(*args, **kwargs) if truncate_microseconds: - assert 'storage_format' not in kwargs, "You can specify only "\ + assert "storage_format" not in kwargs, ( + "You can specify only " "one of truncate_microseconds or storage_format." - assert 'regexp' not in kwargs, "You can specify only one of "\ + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " "truncate_microseconds or regexp." + ) self._storage_format = ( "%(year)04d-%(month)02d-%(day)02d " "%(hour)02d:%(minute)02d:%(second)02d" @@ -693,33 +714,37 @@ class DATETIME(_DateTimeMixin, sqltypes.DateTime): return None elif isinstance(value, datetime_datetime): return format % { - 'year': value.year, - 'month': value.month, - 'day': value.day, - 'hour': value.hour, - 'minute': value.minute, - 'second': value.second, - 'microsecond': value.microsecond, + "year": value.year, + "month": value.month, + "day": value.day, + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, } elif isinstance(value, datetime_date): return format % { - 'year': value.year, - 'month': value.month, - 'day': value.day, - 'hour': 0, - 'minute': 0, - 'second': 0, - 'microsecond': 0, + "year": value.year, + "month": value.month, + "day": value.day, + "hour": 0, + "minute": 0, + "second": 0, + "microsecond": 0, } else: - raise TypeError("SQLite DateTime type only accepts Python " - "datetime and date objects as input.") + raise TypeError( + "SQLite DateTime type only accepts Python " + "datetime and date objects as input." + ) + return process def result_processor(self, dialect, coltype): if self._reg: return processors.str_to_datetime_processor_factory( - self._reg, datetime.datetime) + self._reg, datetime.datetime + ) else: return processors.str_to_datetime @@ -768,19 +793,23 @@ class DATE(_DateTimeMixin, sqltypes.Date): return None elif isinstance(value, datetime_date): return format % { - 'year': value.year, - 'month': value.month, - 'day': value.day, + "year": value.year, + "month": value.month, + "day": value.day, } else: - raise TypeError("SQLite Date type only accepts Python " - "date objects as input.") + raise TypeError( + "SQLite Date type only accepts Python " + "date objects as input." + ) + return process def result_processor(self, dialect, coltype): if self._reg: return processors.str_to_datetime_processor_factory( - self._reg, datetime.date) + self._reg, datetime.date + ) else: return processors.str_to_date @@ -820,13 +849,17 @@ class TIME(_DateTimeMixin, sqltypes.Time): _storage_format = "%(hour)02d:%(minute)02d:%(second)02d.%(microsecond)06d" def __init__(self, *args, **kwargs): - truncate_microseconds = kwargs.pop('truncate_microseconds', False) + truncate_microseconds = kwargs.pop("truncate_microseconds", False) super(TIME, self).__init__(*args, **kwargs) if truncate_microseconds: - assert 'storage_format' not in kwargs, "You can specify only "\ + assert "storage_format" not in kwargs, ( + "You can specify only " "one of truncate_microseconds or storage_format." - assert 'regexp' not in kwargs, "You can specify only one of "\ + ) + assert "regexp" not in kwargs, ( + "You can specify only one of " "truncate_microseconds or regexp." + ) self._storage_format = "%(hour)02d:%(minute)02d:%(second)02d" def bind_processor(self, dialect): @@ -838,23 +871,28 @@ class TIME(_DateTimeMixin, sqltypes.Time): return None elif isinstance(value, datetime_time): return format % { - 'hour': value.hour, - 'minute': value.minute, - 'second': value.second, - 'microsecond': value.microsecond, + "hour": value.hour, + "minute": value.minute, + "second": value.second, + "microsecond": value.microsecond, } else: - raise TypeError("SQLite Time type only accepts Python " - "time objects as input.") + raise TypeError( + "SQLite Time type only accepts Python " + "time objects as input." + ) + return process def result_processor(self, dialect, coltype): if self._reg: return processors.str_to_datetime_processor_factory( - self._reg, datetime.time) + self._reg, datetime.time + ) else: return processors.str_to_time + colspecs = { sqltypes.Date: DATE, sqltypes.DateTime: DATETIME, @@ -865,31 +903,31 @@ colspecs = { } ischema_names = { - 'BIGINT': sqltypes.BIGINT, - 'BLOB': sqltypes.BLOB, - 'BOOL': sqltypes.BOOLEAN, - 'BOOLEAN': sqltypes.BOOLEAN, - 'CHAR': sqltypes.CHAR, - 'DATE': sqltypes.DATE, - 'DATE_CHAR': sqltypes.DATE, - 'DATETIME': sqltypes.DATETIME, - 'DATETIME_CHAR': sqltypes.DATETIME, - 'DOUBLE': sqltypes.FLOAT, - 'DECIMAL': sqltypes.DECIMAL, - 'FLOAT': sqltypes.FLOAT, - 'INT': sqltypes.INTEGER, - 'INTEGER': sqltypes.INTEGER, - 'JSON': JSON, - 'NUMERIC': sqltypes.NUMERIC, - 'REAL': sqltypes.REAL, - 'SMALLINT': sqltypes.SMALLINT, - 'TEXT': sqltypes.TEXT, - 'TIME': sqltypes.TIME, - 'TIME_CHAR': sqltypes.TIME, - 'TIMESTAMP': sqltypes.TIMESTAMP, - 'VARCHAR': sqltypes.VARCHAR, - 'NVARCHAR': sqltypes.NVARCHAR, - 'NCHAR': sqltypes.NCHAR, + "BIGINT": sqltypes.BIGINT, + "BLOB": sqltypes.BLOB, + "BOOL": sqltypes.BOOLEAN, + "BOOLEAN": sqltypes.BOOLEAN, + "CHAR": sqltypes.CHAR, + "DATE": sqltypes.DATE, + "DATE_CHAR": sqltypes.DATE, + "DATETIME": sqltypes.DATETIME, + "DATETIME_CHAR": sqltypes.DATETIME, + "DOUBLE": sqltypes.FLOAT, + "DECIMAL": sqltypes.DECIMAL, + "FLOAT": sqltypes.FLOAT, + "INT": sqltypes.INTEGER, + "INTEGER": sqltypes.INTEGER, + "JSON": JSON, + "NUMERIC": sqltypes.NUMERIC, + "REAL": sqltypes.REAL, + "SMALLINT": sqltypes.SMALLINT, + "TEXT": sqltypes.TEXT, + "TIME": sqltypes.TIME, + "TIME_CHAR": sqltypes.TIME, + "TIMESTAMP": sqltypes.TIMESTAMP, + "VARCHAR": sqltypes.VARCHAR, + "NVARCHAR": sqltypes.NVARCHAR, + "NCHAR": sqltypes.NCHAR, } @@ -897,17 +935,18 @@ class SQLiteCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, { - 'month': '%m', - 'day': '%d', - 'year': '%Y', - 'second': '%S', - 'hour': '%H', - 'doy': '%j', - 'minute': '%M', - 'epoch': '%s', - 'dow': '%w', - 'week': '%W', - }) + "month": "%m", + "day": "%d", + "year": "%Y", + "second": "%S", + "hour": "%H", + "doy": "%j", + "minute": "%M", + "epoch": "%s", + "dow": "%w", + "week": "%W", + }, + ) def visit_now_func(self, fn, **kw): return "CURRENT_TIMESTAMP" @@ -916,10 +955,10 @@ class SQLiteCompiler(compiler.SQLCompiler): return 'DATETIME(CURRENT_TIMESTAMP, "localtime")' def visit_true(self, expr, **kw): - return '1' + return "1" def visit_false(self, expr, **kw): - return '0' + return "0" def visit_char_length_func(self, fn, **kw): return "length%s" % self.function_argspec(fn) @@ -934,11 +973,12 @@ class SQLiteCompiler(compiler.SQLCompiler): try: return "CAST(STRFTIME('%s', %s) AS INTEGER)" % ( self.extract_map[extract.field], - self.process(extract.expr, **kw) + self.process(extract.expr, **kw), ) except KeyError: raise exc.CompileError( - "%s is not a valid extract argument." % extract.field) + "%s is not a valid extract argument." % extract.field + ) def limit_clause(self, select, **kw): text = "" @@ -954,35 +994,41 @@ class SQLiteCompiler(compiler.SQLCompiler): def for_update_clause(self, select, **kw): # sqlite has no "FOR UPDATE" AFAICT - return '' + return "" def visit_is_distinct_from_binary(self, binary, operator, **kw): - return "%s IS NOT %s" % (self.process(binary.left), - self.process(binary.right)) + return "%s IS NOT %s" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_isnot_distinct_from_binary(self, binary, operator, **kw): - return "%s IS %s" % (self.process(binary.left), - self.process(binary.right)) + return "%s IS %s" % ( + self.process(binary.left), + self.process(binary.right), + ) def visit_json_getitem_op_binary(self, binary, operator, **kw): return "JSON_QUOTE(JSON_EXTRACT(%s, %s))" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_json_path_getitem_op_binary(self, binary, operator, **kw): return "JSON_QUOTE(JSON_EXTRACT(%s, %s))" % ( self.process(binary.left, **kw), - self.process(binary.right, **kw)) + self.process(binary.right, **kw), + ) def visit_empty_set_expr(self, type_): - return 'SELECT 1 FROM (SELECT 1) WHERE 1!=1' + return "SELECT 1 FROM (SELECT 1) WHERE 1!=1" class SQLiteDDLCompiler(compiler.DDLCompiler): - def get_column_specification(self, column, **kwargs): coltype = self.dialect.type_compiler.process( - column.type, type_expression=column) + column.type, type_expression=column + ) colspec = self.preparer.format_column(column) + " " + coltype default = self.get_column_default_string(column) if default is not None: @@ -991,29 +1037,33 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if not column.nullable: colspec += " NOT NULL" - on_conflict_clause = column.dialect_options['sqlite'][ - 'on_conflict_not_null'] + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_not_null" + ] if on_conflict_clause is not None: colspec += " ON CONFLICT " + on_conflict_clause if column.primary_key: if ( - column.autoincrement is True and - len(column.table.primary_key.columns) != 1 + column.autoincrement is True + and len(column.table.primary_key.columns) != 1 ): raise exc.CompileError( "SQLite does not support autoincrement for " - "composite primary keys") + "composite primary keys" + ) - if (column.table.dialect_options['sqlite']['autoincrement'] and - len(column.table.primary_key.columns) == 1 and - issubclass( - column.type._type_affinity, sqltypes.Integer) and - not column.foreign_keys): + if ( + column.table.dialect_options["sqlite"]["autoincrement"] + and len(column.table.primary_key.columns) == 1 + and issubclass(column.type._type_affinity, sqltypes.Integer) + and not column.foreign_keys + ): colspec += " PRIMARY KEY" - on_conflict_clause = column.dialect_options['sqlite'][ - 'on_conflict_primary_key'] + on_conflict_clause = column.dialect_options["sqlite"][ + "on_conflict_primary_key" + ] if on_conflict_clause is not None: colspec += " ON CONFLICT " + on_conflict_clause @@ -1027,21 +1077,25 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): # with the column itself. if len(constraint.columns) == 1: c = list(constraint)[0] - if (c.primary_key and - c.table.dialect_options['sqlite']['autoincrement'] and - issubclass(c.type._type_affinity, sqltypes.Integer) and - not c.foreign_keys): + if ( + c.primary_key + and c.table.dialect_options["sqlite"]["autoincrement"] + and issubclass(c.type._type_affinity, sqltypes.Integer) + and not c.foreign_keys + ): return None - text = super( - SQLiteDDLCompiler, - self).visit_primary_key_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_primary_key_constraint( + constraint + ) - on_conflict_clause = constraint.dialect_options['sqlite'][ - 'on_conflict'] + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] if on_conflict_clause is None and len(constraint.columns) == 1: - on_conflict_clause = list(constraint)[0].\ - dialect_options['sqlite']['on_conflict_primary_key'] + on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ + "on_conflict_primary_key" + ] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause @@ -1049,15 +1103,17 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_unique_constraint(self, constraint): - text = super( - SQLiteDDLCompiler, - self).visit_unique_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_unique_constraint( + constraint + ) - on_conflict_clause = constraint.dialect_options['sqlite'][ - 'on_conflict'] + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] if on_conflict_clause is None and len(constraint.columns) == 1: - on_conflict_clause = list(constraint)[0].\ - dialect_options['sqlite']['on_conflict_unique'] + on_conflict_clause = list(constraint)[0].dialect_options["sqlite"][ + "on_conflict_unique" + ] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause @@ -1065,12 +1121,13 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_check_constraint(self, constraint): - text = super( - SQLiteDDLCompiler, - self).visit_check_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_check_constraint( + constraint + ) - on_conflict_clause = constraint.dialect_options['sqlite'][ - 'on_conflict'] + on_conflict_clause = constraint.dialect_options["sqlite"][ + "on_conflict" + ] if on_conflict_clause is not None: text += " ON CONFLICT " + on_conflict_clause @@ -1078,14 +1135,15 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): return text def visit_column_check_constraint(self, constraint): - text = super( - SQLiteDDLCompiler, - self).visit_column_check_constraint(constraint) + text = super(SQLiteDDLCompiler, self).visit_column_check_constraint( + constraint + ) - if constraint.dialect_options['sqlite']['on_conflict'] is not None: + if constraint.dialect_options["sqlite"]["on_conflict"] is not None: raise exc.CompileError( "SQLite does not support on conflict clause for " - "column check constraint") + "column check constraint" + ) return text @@ -1097,40 +1155,40 @@ class SQLiteDDLCompiler(compiler.DDLCompiler): if local_table.schema != remote_table.schema: return None else: - return super( - SQLiteDDLCompiler, - self).visit_foreign_key_constraint(constraint) + return super(SQLiteDDLCompiler, self).visit_foreign_key_constraint( + constraint + ) def define_constraint_remote_table(self, constraint, table, preparer): """Format the remote table clause of a CREATE CONSTRAINT clause.""" return preparer.format_table(table, use_schema=False) - def visit_create_index(self, create, include_schema=False, - include_table_schema=True): + def visit_create_index( + self, create, include_schema=False, include_table_schema=True + ): index = create.element self._verify_index_table(index) preparer = self.preparer text = "CREATE " if index.unique: text += "UNIQUE " - text += "INDEX %s ON %s (%s)" \ - % ( - self._prepared_index_name(index, - include_schema=True), - preparer.format_table(index.table, - use_schema=False), - ', '.join( - self.sql_compiler.process( - expr, include_table=False, literal_binds=True) for - expr in index.expressions) - ) + text += "INDEX %s ON %s (%s)" % ( + self._prepared_index_name(index, include_schema=True), + preparer.format_table(index.table, use_schema=False), + ", ".join( + self.sql_compiler.process( + expr, include_table=False, literal_binds=True + ) + for expr in index.expressions + ), + ) whereclause = index.dialect_options["sqlite"]["where"] if whereclause is not None: where_compiled = self.sql_compiler.process( - whereclause, include_table=False, - literal_binds=True) + whereclause, include_table=False, literal_binds=True + ) text += " WHERE " + where_compiled return text @@ -1141,22 +1199,28 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): return self.visit_BLOB(type_) def visit_DATETIME(self, type_, **kw): - if not isinstance(type_, _DateTimeMixin) or \ - type_.format_is_text_affinity: + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): return super(SQLiteTypeCompiler, self).visit_DATETIME(type_) else: return "DATETIME_CHAR" def visit_DATE(self, type_, **kw): - if not isinstance(type_, _DateTimeMixin) or \ - type_.format_is_text_affinity: + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): return super(SQLiteTypeCompiler, self).visit_DATE(type_) else: return "DATE_CHAR" def visit_TIME(self, type_, **kw): - if not isinstance(type_, _DateTimeMixin) or \ - type_.format_is_text_affinity: + if ( + not isinstance(type_, _DateTimeMixin) + or type_.format_is_text_affinity + ): return super(SQLiteTypeCompiler, self).visit_TIME(type_) else: return "TIME_CHAR" @@ -1169,33 +1233,135 @@ class SQLiteTypeCompiler(compiler.GenericTypeCompiler): class SQLiteIdentifierPreparer(compiler.IdentifierPreparer): - reserved_words = set([ - 'add', 'after', 'all', 'alter', 'analyze', 'and', 'as', 'asc', - 'attach', 'autoincrement', 'before', 'begin', 'between', 'by', - 'cascade', 'case', 'cast', 'check', 'collate', 'column', 'commit', - 'conflict', 'constraint', 'create', 'cross', 'current_date', - 'current_time', 'current_timestamp', 'database', 'default', - 'deferrable', 'deferred', 'delete', 'desc', 'detach', 'distinct', - 'drop', 'each', 'else', 'end', 'escape', 'except', 'exclusive', - 'explain', 'false', 'fail', 'for', 'foreign', 'from', 'full', 'glob', - 'group', 'having', 'if', 'ignore', 'immediate', 'in', 'index', - 'indexed', 'initially', 'inner', 'insert', 'instead', 'intersect', - 'into', 'is', 'isnull', 'join', 'key', 'left', 'like', 'limit', - 'match', 'natural', 'not', 'notnull', 'null', 'of', 'offset', 'on', - 'or', 'order', 'outer', 'plan', 'pragma', 'primary', 'query', - 'raise', 'references', 'reindex', 'rename', 'replace', 'restrict', - 'right', 'rollback', 'row', 'select', 'set', 'table', 'temp', - 'temporary', 'then', 'to', 'transaction', 'trigger', 'true', 'union', - 'unique', 'update', 'using', 'vacuum', 'values', 'view', 'virtual', - 'when', 'where', - ]) + reserved_words = set( + [ + "add", + "after", + "all", + "alter", + "analyze", + "and", + "as", + "asc", + "attach", + "autoincrement", + "before", + "begin", + "between", + "by", + "cascade", + "case", + "cast", + "check", + "collate", + "column", + "commit", + "conflict", + "constraint", + "create", + "cross", + "current_date", + "current_time", + "current_timestamp", + "database", + "default", + "deferrable", + "deferred", + "delete", + "desc", + "detach", + "distinct", + "drop", + "each", + "else", + "end", + "escape", + "except", + "exclusive", + "explain", + "false", + "fail", + "for", + "foreign", + "from", + "full", + "glob", + "group", + "having", + "if", + "ignore", + "immediate", + "in", + "index", + "indexed", + "initially", + "inner", + "insert", + "instead", + "intersect", + "into", + "is", + "isnull", + "join", + "key", + "left", + "like", + "limit", + "match", + "natural", + "not", + "notnull", + "null", + "of", + "offset", + "on", + "or", + "order", + "outer", + "plan", + "pragma", + "primary", + "query", + "raise", + "references", + "reindex", + "rename", + "replace", + "restrict", + "right", + "rollback", + "row", + "select", + "set", + "table", + "temp", + "temporary", + "then", + "to", + "transaction", + "trigger", + "true", + "union", + "unique", + "update", + "using", + "vacuum", + "values", + "view", + "virtual", + "when", + "where", + ] + ) class SQLiteExecutionContext(default.DefaultExecutionContext): @util.memoized_property def _preserve_raw_colnames(self): - return not self.dialect._broken_dotted_colnames or \ - self.execution_options.get("sqlite_raw_colnames", False) + return ( + not self.dialect._broken_dotted_colnames + or self.execution_options.get("sqlite_raw_colnames", False) + ) def _translate_colname(self, colname): # TODO: detect SQLite version 3.10.0 or greater; @@ -1212,7 +1378,7 @@ class SQLiteExecutionContext(default.DefaultExecutionContext): class SQLiteDialect(default.DefaultDialect): - name = 'sqlite' + name = "sqlite" supports_alter = False supports_unicode_statements = True supports_unicode_binds = True @@ -1221,7 +1387,7 @@ class SQLiteDialect(default.DefaultDialect): supports_cast = True supports_multivalues_insert = True - default_paramstyle = 'qmark' + default_paramstyle = "qmark" execution_ctx_cls = SQLiteExecutionContext statement_compiler = SQLiteCompiler ddl_compiler = SQLiteDDLCompiler @@ -1235,27 +1401,30 @@ class SQLiteDialect(default.DefaultDialect): supports_default_values = True construct_arguments = [ - (sa_schema.Table, { - "autoincrement": False - }), - (sa_schema.Index, { - "where": None, - }), - (sa_schema.Column, { - "on_conflict_primary_key": None, - "on_conflict_not_null": None, - "on_conflict_unique": None, - }), - (sa_schema.Constraint, { - "on_conflict": None, - }), + (sa_schema.Table, {"autoincrement": False}), + (sa_schema.Index, {"where": None}), + ( + sa_schema.Column, + { + "on_conflict_primary_key": None, + "on_conflict_not_null": None, + "on_conflict_unique": None, + }, + ), + (sa_schema.Constraint, {"on_conflict": None}), ] _broken_fk_pragma_quotes = False _broken_dotted_colnames = False - def __init__(self, isolation_level=None, native_datetime=False, - _json_serializer=None, _json_deserializer=None, **kwargs): + def __init__( + self, + isolation_level=None, + native_datetime=False, + _json_serializer=None, + _json_deserializer=None, + **kwargs + ): default.DefaultDialect.__init__(self, **kwargs) self.isolation_level = isolation_level self._json_serializer = _json_serializer @@ -1269,35 +1438,42 @@ class SQLiteDialect(default.DefaultDialect): if self.dbapi is not None: self.supports_right_nested_joins = ( - self.dbapi.sqlite_version_info >= (3, 7, 16)) - self._broken_dotted_colnames = ( - self.dbapi.sqlite_version_info < (3, 10, 0) + self.dbapi.sqlite_version_info >= (3, 7, 16) + ) + self._broken_dotted_colnames = self.dbapi.sqlite_version_info < ( + 3, + 10, + 0, + ) + self.supports_default_values = self.dbapi.sqlite_version_info >= ( + 3, + 3, + 8, ) - self.supports_default_values = ( - self.dbapi.sqlite_version_info >= (3, 3, 8)) - self.supports_cast = ( - self.dbapi.sqlite_version_info >= (3, 2, 3)) + self.supports_cast = self.dbapi.sqlite_version_info >= (3, 2, 3) self.supports_multivalues_insert = ( # http://www.sqlite.org/releaselog/3_7_11.html - self.dbapi.sqlite_version_info >= (3, 7, 11)) + self.dbapi.sqlite_version_info + >= (3, 7, 11) + ) # see http://www.sqlalchemy.org/trac/ticket/2568 # as well as http://www.sqlite.org/src/info/600482d161 - self._broken_fk_pragma_quotes = ( - self.dbapi.sqlite_version_info < (3, 6, 14)) + self._broken_fk_pragma_quotes = self.dbapi.sqlite_version_info < ( + 3, + 6, + 14, + ) - _isolation_lookup = { - 'READ UNCOMMITTED': 1, - 'SERIALIZABLE': 0, - } + _isolation_lookup = {"READ UNCOMMITTED": 1, "SERIALIZABLE": 0} def set_isolation_level(self, connection, level): try: - isolation_level = self._isolation_lookup[level.replace('_', ' ')] + isolation_level = self._isolation_lookup[level.replace("_", " ")] except KeyError: raise exc.ArgumentError( "Invalid value '%s' for isolation_level. " - "Valid isolation levels for %s are %s" % - (level, self.name, ", ".join(self._isolation_lookup)) + "Valid isolation levels for %s are %s" + % (level, self.name, ", ".join(self._isolation_lookup)) ) cursor = connection.cursor() cursor.execute("PRAGMA read_uncommitted = %d" % isolation_level) @@ -1305,7 +1481,7 @@ class SQLiteDialect(default.DefaultDialect): def get_isolation_level(self, connection): cursor = connection.cursor() - cursor.execute('PRAGMA read_uncommitted') + cursor.execute("PRAGMA read_uncommitted") res = cursor.fetchone() if res: value = res[0] @@ -1327,8 +1503,10 @@ class SQLiteDialect(default.DefaultDialect): def on_connect(self): if self.isolation_level is not None: + def connect(conn): self.set_isolation_level(conn, self.isolation_level) + return connect else: return None @@ -1344,44 +1522,51 @@ class SQLiteDialect(default.DefaultDialect): def get_table_names(self, connection, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema + master = "%s.sqlite_master" % qschema else: master = "sqlite_master" - s = ("SELECT name FROM %s " - "WHERE type='table' ORDER BY name") % (master,) + s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % ( + master, + ) rs = connection.execute(s) return [row[0] for row in rs] @reflection.cache def get_temp_table_names(self, connection, **kw): - s = "SELECT name FROM sqlite_temp_master "\ + s = ( + "SELECT name FROM sqlite_temp_master " "WHERE type='table' ORDER BY name " + ) rs = connection.execute(s) return [row[0] for row in rs] @reflection.cache def get_temp_view_names(self, connection, **kw): - s = "SELECT name FROM sqlite_temp_master "\ + s = ( + "SELECT name FROM sqlite_temp_master " "WHERE type='view' ORDER BY name " + ) rs = connection.execute(s) return [row[0] for row in rs] def has_table(self, connection, table_name, schema=None): info = self._get_table_pragma( - connection, "table_info", table_name, schema=schema) + connection, "table_info", table_name, schema=schema + ) return bool(info) @reflection.cache def get_view_names(self, connection, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema + master = "%s.sqlite_master" % qschema else: master = "sqlite_master" - s = ("SELECT name FROM %s " - "WHERE type='view' ORDER BY name") % (master,) + s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % ( + master, + ) rs = connection.execute(s) return [row[0] for row in rs] @@ -1390,21 +1575,27 @@ class SQLiteDialect(default.DefaultDialect): def get_view_definition(self, connection, view_name, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = '%s.sqlite_master' % qschema - s = ("SELECT sql FROM %s WHERE name = '%s'" - "AND type='view'") % (master, view_name) + master = "%s.sqlite_master" % qschema + s = ("SELECT sql FROM %s WHERE name = '%s'" "AND type='view'") % ( + master, + view_name, + ) rs = connection.execute(s) else: try: - s = ("SELECT sql FROM " - " (SELECT * FROM sqlite_master UNION ALL " - " SELECT * FROM sqlite_temp_master) " - "WHERE name = '%s' " - "AND type='view'") % view_name + s = ( + "SELECT sql FROM " + " (SELECT * FROM sqlite_master UNION ALL " + " SELECT * FROM sqlite_temp_master) " + "WHERE name = '%s' " + "AND type='view'" + ) % view_name rs = connection.execute(s) except exc.DBAPIError: - s = ("SELECT sql FROM sqlite_master WHERE name = '%s' " - "AND type='view'") % view_name + s = ( + "SELECT sql FROM sqlite_master WHERE name = '%s' " + "AND type='view'" + ) % view_name rs = connection.execute(s) result = rs.fetchall() @@ -1414,15 +1605,24 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): info = self._get_table_pragma( - connection, "table_info", table_name, schema=schema) + connection, "table_info", table_name, schema=schema + ) columns = [] for row in info: (name, type_, nullable, default, primary_key) = ( - row[1], row[2].upper(), not row[3], row[4], row[5]) + row[1], + row[2].upper(), + not row[3], + row[4], + row[5], + ) - columns.append(self._get_column_info(name, type_, nullable, - default, primary_key)) + columns.append( + self._get_column_info( + name, type_, nullable, default, primary_key + ) + ) return columns def _get_column_info(self, name, type_, nullable, default, primary_key): @@ -1432,12 +1632,12 @@ class SQLiteDialect(default.DefaultDialect): default = util.text_type(default) return { - 'name': name, - 'type': coltype, - 'nullable': nullable, - 'default': default, - 'autoincrement': 'auto', - 'primary_key': primary_key, + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": "auto", + "primary_key": primary_key, } def _resolve_type_affinity(self, type_): @@ -1457,36 +1657,37 @@ class SQLiteDialect(default.DefaultDialect): DATE and DOUBLE). """ - match = re.match(r'([\w ]+)(\(.*?\))?', type_) + match = re.match(r"([\w ]+)(\(.*?\))?", type_) if match: coltype = match.group(1) args = match.group(2) else: - coltype = '' - args = '' + coltype = "" + args = "" if coltype in self.ischema_names: coltype = self.ischema_names[coltype] - elif 'INT' in coltype: + elif "INT" in coltype: coltype = sqltypes.INTEGER - elif 'CHAR' in coltype or 'CLOB' in coltype or 'TEXT' in coltype: + elif "CHAR" in coltype or "CLOB" in coltype or "TEXT" in coltype: coltype = sqltypes.TEXT - elif 'BLOB' in coltype or not coltype: + elif "BLOB" in coltype or not coltype: coltype = sqltypes.NullType - elif 'REAL' in coltype or 'FLOA' in coltype or 'DOUB' in coltype: + elif "REAL" in coltype or "FLOA" in coltype or "DOUB" in coltype: coltype = sqltypes.REAL else: coltype = sqltypes.NUMERIC if args is not None: - args = re.findall(r'(\d+)', args) + args = re.findall(r"(\d+)", args) try: coltype = coltype(*[int(a) for a in args]) except TypeError: util.warn( "Could not instantiate type %s with " - "reflected arguments %s; using no arguments." % - (coltype, args)) + "reflected arguments %s; using no arguments." + % (coltype, args) + ) coltype = coltype() else: coltype = coltype() @@ -1498,58 +1699,59 @@ class SQLiteDialect(default.DefaultDialect): constraint_name = None table_data = self._get_table_sql(connection, table_name, schema=schema) if table_data: - PK_PATTERN = r'CONSTRAINT (\w+) PRIMARY KEY' + PK_PATTERN = r"CONSTRAINT (\w+) PRIMARY KEY" result = re.search(PK_PATTERN, table_data, re.I) constraint_name = result.group(1) if result else None cols = self.get_columns(connection, table_name, schema, **kw) pkeys = [] for col in cols: - if col['primary_key']: - pkeys.append(col['name']) + if col["primary_key"]: + pkeys.append(col["name"]) - return {'constrained_columns': pkeys, 'name': constraint_name} + return {"constrained_columns": pkeys, "name": constraint_name} @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): # sqlite makes this *extremely difficult*. # First, use the pragma to get the actual FKs. pragma_fks = self._get_table_pragma( - connection, "foreign_key_list", - table_name, schema=schema + connection, "foreign_key_list", table_name, schema=schema ) fks = {} for row in pragma_fks: - (numerical_id, rtbl, lcol, rcol) = ( - row[0], row[2], row[3], row[4]) + (numerical_id, rtbl, lcol, rcol) = (row[0], row[2], row[3], row[4]) if rcol is None: rcol = lcol if self._broken_fk_pragma_quotes: - rtbl = re.sub(r'^[\"\[`\']|[\"\]`\']$', '', rtbl) + rtbl = re.sub(r"^[\"\[`\']|[\"\]`\']$", "", rtbl) if numerical_id in fks: fk = fks[numerical_id] else: fk = fks[numerical_id] = { - 'name': None, - 'constrained_columns': [], - 'referred_schema': schema, - 'referred_table': rtbl, - 'referred_columns': [], - 'options': {} + "name": None, + "constrained_columns": [], + "referred_schema": schema, + "referred_table": rtbl, + "referred_columns": [], + "options": {}, } fks[numerical_id] = fk - fk['constrained_columns'].append(lcol) - fk['referred_columns'].append(rcol) + fk["constrained_columns"].append(lcol) + fk["referred_columns"].append(rcol) def fk_sig(constrained_columns, referred_table, referred_columns): - return tuple(constrained_columns) + (referred_table,) + \ - tuple(referred_columns) + return ( + tuple(constrained_columns) + + (referred_table,) + + tuple(referred_columns) + ) # then, parse the actual SQL and attempt to find DDL that matches # the names as well. SQLite saves the DDL in whatever format @@ -1558,10 +1760,13 @@ class SQLiteDialect(default.DefaultDialect): keys_by_signature = dict( ( fk_sig( - fk['constrained_columns'], - fk['referred_table'], fk['referred_columns']), - fk - ) for fk in fks.values() + fk["constrained_columns"], + fk["referred_table"], + fk["referred_columns"], + ), + fk, + ) + for fk in fks.values() ) table_data = self._get_table_sql(connection, table_name, schema=schema) @@ -1571,55 +1776,66 @@ class SQLiteDialect(default.DefaultDialect): def parse_fks(): FK_PATTERN = ( - r'(?:CONSTRAINT (\w+) +)?' - r'FOREIGN KEY *\( *(.+?) *\) +' + r"(?:CONSTRAINT (\w+) +)?" + r"FOREIGN KEY *\( *(.+?) *\) +" r'REFERENCES +(?:(?:"(.+?)")|([a-z0-9_]+)) *\((.+?)\) *' - r'((?:ON (?:DELETE|UPDATE) ' - r'(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)' + r"((?:ON (?:DELETE|UPDATE) " + r"(?:SET NULL|SET DEFAULT|CASCADE|RESTRICT|NO ACTION) *)*)" ) for match in re.finditer(FK_PATTERN, table_data, re.I): ( - constraint_name, constrained_columns, - referred_quoted_name, referred_name, - referred_columns, onupdatedelete) = \ - match.group(1, 2, 3, 4, 5, 6) + constraint_name, + constrained_columns, + referred_quoted_name, + referred_name, + referred_columns, + onupdatedelete, + ) = match.group(1, 2, 3, 4, 5, 6) constrained_columns = list( - self._find_cols_in_sig(constrained_columns)) + self._find_cols_in_sig(constrained_columns) + ) if not referred_columns: referred_columns = constrained_columns else: referred_columns = list( - self._find_cols_in_sig(referred_columns)) + self._find_cols_in_sig(referred_columns) + ) referred_name = referred_quoted_name or referred_name options = {} for token in re.split(r" *\bON\b *", onupdatedelete.upper()): if token.startswith("DELETE"): - options['ondelete'] = token[6:].strip() + options["ondelete"] = token[6:].strip() elif token.startswith("UPDATE"): options["onupdate"] = token[6:].strip() yield ( - constraint_name, constrained_columns, - referred_name, referred_columns, options) + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) + fkeys = [] for ( - constraint_name, constrained_columns, - referred_name, referred_columns, options) in parse_fks(): - sig = fk_sig( - constrained_columns, referred_name, referred_columns) + constraint_name, + constrained_columns, + referred_name, + referred_columns, + options, + ) in parse_fks(): + sig = fk_sig(constrained_columns, referred_name, referred_columns) if sig not in keys_by_signature: util.warn( "WARNING: SQL-parsed foreign key constraint " "'%s' could not be located in PRAGMA " - "foreign_keys for table %s" % ( - sig, - table_name - )) + "foreign_keys for table %s" % (sig, table_name) + ) continue key = keys_by_signature.pop(sig) - key['name'] = constraint_name - key['options'] = options + key["name"] = constraint_name + key["options"] = options fkeys.append(key) # assume the remainders are the unnamed, inline constraints, just # use them as is as it's extremely difficult to parse inline @@ -1632,20 +1848,26 @@ class SQLiteDialect(default.DefaultDialect): yield match.group(1) or match.group(2) @reflection.cache - def get_unique_constraints(self, connection, table_name, - schema=None, **kw): + def get_unique_constraints( + self, connection, table_name, schema=None, **kw + ): auto_index_by_sig = {} for idx in self.get_indexes( - connection, table_name, schema=schema, - include_auto_indexes=True, **kw): - if not idx['name'].startswith("sqlite_autoindex"): + connection, + table_name, + schema=schema, + include_auto_indexes=True, + **kw + ): + if not idx["name"].startswith("sqlite_autoindex"): continue - sig = tuple(idx['column_names']) + sig = tuple(idx["column_names"]) auto_index_by_sig[sig] = idx table_data = self._get_table_sql( - connection, table_name, schema=schema, **kw) + connection, table_name, schema=schema, **kw + ) if not table_data: return [] @@ -1654,8 +1876,8 @@ class SQLiteDialect(default.DefaultDialect): def parse_uqs(): UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( - r'(?:(".+?")|([a-z0-9]+)) ' - r'+[a-z0-9_ ]+? +UNIQUE') + r'(?:(".+?")|([a-z0-9]+)) ' r"+[a-z0-9_ ]+? +UNIQUE" + ) for match in re.finditer(UNIQUE_PATTERN, table_data, re.I): name, cols = match.group(1, 2) @@ -1666,34 +1888,29 @@ class SQLiteDialect(default.DefaultDialect): # are kind of the same thing :) for match in re.finditer(INLINE_UNIQUE_PATTERN, table_data, re.I): cols = list( - self._find_cols_in_sig(match.group(1) or match.group(2))) + self._find_cols_in_sig(match.group(1) or match.group(2)) + ) yield None, cols for name, cols in parse_uqs(): sig = tuple(cols) if sig in auto_index_by_sig: auto_index_by_sig.pop(sig) - parsed_constraint = { - 'name': name, - 'column_names': cols - } + parsed_constraint = {"name": name, "column_names": cols} unique_constraints.append(parsed_constraint) # NOTE: auto_index_by_sig might not be empty here, # the PRIMARY KEY may have an entry. return unique_constraints @reflection.cache - def get_check_constraints(self, connection, table_name, - schema=None, **kw): + def get_check_constraints(self, connection, table_name, schema=None, **kw): table_data = self._get_table_sql( - connection, table_name, schema=schema, **kw) + connection, table_name, schema=schema, **kw + ) if not table_data: return [] - CHECK_PATTERN = ( - r'(?:CONSTRAINT (\w+) +)?' - r'CHECK *\( *(.+) *\),? *' - ) + CHECK_PATTERN = r"(?:CONSTRAINT (\w+) +)?" r"CHECK *\( *(.+) *\),? *" check_constraints = [] # NOTE: we aren't using re.S here because we actually are # taking advantage of each CHECK constraint being all on one @@ -1701,25 +1918,26 @@ class SQLiteDialect(default.DefaultDialect): # necessarily makes assumptions as to how the CREATE TABLE # was emitted. for match in re.finditer(CHECK_PATTERN, table_data, re.I): - check_constraints.append({ - 'sqltext': match.group(2), - 'name': match.group(1) - }) + check_constraints.append( + {"sqltext": match.group(2), "name": match.group(1)} + ) return check_constraints @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): pragma_indexes = self._get_table_pragma( - connection, "index_list", table_name, schema=schema) + connection, "index_list", table_name, schema=schema + ) indexes = [] - include_auto_indexes = kw.pop('include_auto_indexes', False) + include_auto_indexes = kw.pop("include_auto_indexes", False) for row in pragma_indexes: # ignore implicit primary key index. # http://www.mail-archive.com/sqlite-users@sqlite.org/msg30517.html - if (not include_auto_indexes and - row[1].startswith('sqlite_autoindex')): + if not include_auto_indexes and row[1].startswith( + "sqlite_autoindex" + ): continue indexes.append(dict(name=row[1], column_names=[], unique=row[2])) @@ -1727,34 +1945,38 @@ class SQLiteDialect(default.DefaultDialect): # loop thru unique indexes to get the column names. for idx in indexes: pragma_index = self._get_table_pragma( - connection, "index_info", idx['name']) + connection, "index_info", idx["name"] + ) for row in pragma_index: - idx['column_names'].append(row[2]) + idx["column_names"].append(row[2]) return indexes @reflection.cache def _get_table_sql(self, connection, table_name, schema=None, **kw): if schema: schema_expr = "%s." % ( - self.identifier_preparer.quote_identifier(schema)) + self.identifier_preparer.quote_identifier(schema) + ) else: schema_expr = "" try: - s = ("SELECT sql FROM " - " (SELECT * FROM %(schema)ssqlite_master UNION ALL " - " SELECT * FROM %(schema)ssqlite_temp_master) " - "WHERE name = '%(table)s' " - "AND type = 'table'" % { - "schema": schema_expr, - "table": table_name}) + s = ( + "SELECT sql FROM " + " (SELECT * FROM %(schema)ssqlite_master UNION ALL " + " SELECT * FROM %(schema)ssqlite_temp_master) " + "WHERE name = '%(table)s' " + "AND type = 'table'" + % {"schema": schema_expr, "table": table_name} + ) rs = connection.execute(s) except exc.DBAPIError: - s = ("SELECT sql FROM %(schema)ssqlite_master " - "WHERE name = '%(table)s' " - "AND type = 'table'" % { - "schema": schema_expr, - "table": table_name}) + s = ( + "SELECT sql FROM %(schema)ssqlite_master " + "WHERE name = '%(table)s' " + "AND type = 'table'" + % {"schema": schema_expr, "table": table_name} + ) rs = connection.execute(s) return rs.scalar() diff --git a/lib/sqlalchemy/dialects/sqlite/json.py b/lib/sqlalchemy/dialects/sqlite/json.py index 90929fbd8..db185dd4d 100644 --- a/lib/sqlalchemy/dialects/sqlite/json.py +++ b/lib/sqlalchemy/dialects/sqlite/json.py @@ -58,7 +58,6 @@ class _FormatTypeMixin(object): class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): - def _format_value(self, value): if isinstance(value, int): value = "$[%s]" % value @@ -70,8 +69,10 @@ class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType): class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType): def _format_value(self, value): return "$%s" % ( - "".join([ - "[%s]" % elem if isinstance(elem, int) - else '."%s"' % elem for elem in value - ]) + "".join( + [ + "[%s]" % elem if isinstance(elem, int) else '."%s"' % elem + for elem in value + ] + ) ) diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py index 09f2b8009..fca425127 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlcipher.py @@ -82,9 +82,9 @@ from ... import pool class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): - driver = 'pysqlcipher' + driver = "pysqlcipher" - pragmas = ('kdf_iter', 'cipher', 'cipher_page_size', 'cipher_use_hmac') + pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac") @classmethod def dbapi(cls): @@ -102,15 +102,13 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): return pool.SingletonThreadPool def connect(self, *cargs, **cparams): - passphrase = cparams.pop('passphrase', '') + passphrase = cparams.pop("passphrase", "") - pragmas = dict( - (key, cparams.pop(key, None)) for key in - self.pragmas - ) + pragmas = dict((key, cparams.pop(key, None)) for key in self.pragmas) - conn = super(SQLiteDialect_pysqlcipher, self).\ - connect(*cargs, **cparams) + conn = super(SQLiteDialect_pysqlcipher, self).connect( + *cargs, **cparams + ) conn.execute('pragma key="%s"' % passphrase) for prag, value in pragmas.items(): if value is not None: @@ -120,11 +118,17 @@ class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite): def create_connect_args(self, url): super_url = _url.URL( - url.drivername, username=url.username, - host=url.host, database=url.database, query=url.query) - c_args, opts = super(SQLiteDialect_pysqlcipher, self).\ - create_connect_args(super_url) - opts['passphrase'] = url.password + url.drivername, + username=url.username, + host=url.host, + database=url.database, + query=url.query, + ) + c_args, opts = super( + SQLiteDialect_pysqlcipher, self + ).create_connect_args(super_url) + opts["passphrase"] = url.password return c_args, opts + dialect = SQLiteDialect_pysqlcipher diff --git a/lib/sqlalchemy/dialects/sqlite/pysqlite.py b/lib/sqlalchemy/dialects/sqlite/pysqlite.py index 8809962df..e78d76ae6 100644 --- a/lib/sqlalchemy/dialects/sqlite/pysqlite.py +++ b/lib/sqlalchemy/dialects/sqlite/pysqlite.py @@ -301,20 +301,20 @@ class _SQLite_pysqliteDate(DATE): class SQLiteDialect_pysqlite(SQLiteDialect): - default_paramstyle = 'qmark' + default_paramstyle = "qmark" colspecs = util.update_copy( SQLiteDialect.colspecs, { sqltypes.Date: _SQLite_pysqliteDate, sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp, - } + }, ) if not util.py2k: description_encoding = None - driver = 'pysqlite' + driver = "pysqlite" def __init__(self, **kwargs): SQLiteDialect.__init__(self, **kwargs) @@ -323,10 +323,13 @@ class SQLiteDialect_pysqlite(SQLiteDialect): sqlite_ver = self.dbapi.version_info if sqlite_ver < (2, 1, 3): util.warn( - ("The installed version of pysqlite2 (%s) is out-dated " - "and will cause errors in some cases. Version 2.1.3 " - "or greater is recommended.") % - '.'.join([str(subver) for subver in sqlite_ver])) + ( + "The installed version of pysqlite2 (%s) is out-dated " + "and will cause errors in some cases. Version 2.1.3 " + "or greater is recommended." + ) + % ".".join([str(subver) for subver in sqlite_ver]) + ) @classmethod def dbapi(cls): @@ -341,7 +344,7 @@ class SQLiteDialect_pysqlite(SQLiteDialect): @classmethod def get_pool_class(cls, url): - if url.database and url.database != ':memory:': + if url.database and url.database != ":memory:": return pool.NullPool else: return pool.SingletonThreadPool @@ -356,22 +359,25 @@ class SQLiteDialect_pysqlite(SQLiteDialect): "Valid SQLite URL forms are:\n" " sqlite:///:memory: (or, sqlite://)\n" " sqlite:///relative/path/to/file.db\n" - " sqlite:////absolute/path/to/file.db" % (url,)) - filename = url.database or ':memory:' - if filename != ':memory:': + " sqlite:////absolute/path/to/file.db" % (url,) + ) + filename = url.database or ":memory:" + if filename != ":memory:": filename = os.path.abspath(filename) opts = url.query.copy() - util.coerce_kw_type(opts, 'timeout', float) - util.coerce_kw_type(opts, 'isolation_level', str) - util.coerce_kw_type(opts, 'detect_types', int) - util.coerce_kw_type(opts, 'check_same_thread', bool) - util.coerce_kw_type(opts, 'cached_statements', int) + util.coerce_kw_type(opts, "timeout", float) + util.coerce_kw_type(opts, "isolation_level", str) + util.coerce_kw_type(opts, "detect_types", int) + util.coerce_kw_type(opts, "check_same_thread", bool) + util.coerce_kw_type(opts, "cached_statements", int) return ([filename], opts) def is_disconnect(self, e, connection, cursor): - return isinstance(e, self.dbapi.ProgrammingError) and \ - "Cannot operate on a closed database." in str(e) + return isinstance( + e, self.dbapi.ProgrammingError + ) and "Cannot operate on a closed database." in str(e) + dialect = SQLiteDialect_pysqlite diff --git a/lib/sqlalchemy/dialects/sybase/__init__.py b/lib/sqlalchemy/dialects/sybase/__init__.py index be434977f..2f55d3bf6 100644 --- a/lib/sqlalchemy/dialects/sybase/__init__.py +++ b/lib/sqlalchemy/dialects/sybase/__init__.py @@ -7,21 +7,61 @@ from . import base, pysybase, pyodbc # noqa -from .base import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ - TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ - BIGINT, INT, INTEGER, SMALLINT, BINARY,\ - VARBINARY, UNITEXT, UNICHAR, UNIVARCHAR,\ - IMAGE, BIT, MONEY, SMALLMONEY, TINYINT +from .base import ( + CHAR, + VARCHAR, + TIME, + NCHAR, + NVARCHAR, + TEXT, + DATE, + DATETIME, + FLOAT, + NUMERIC, + BIGINT, + INT, + INTEGER, + SMALLINT, + BINARY, + VARBINARY, + UNITEXT, + UNICHAR, + UNIVARCHAR, + IMAGE, + BIT, + MONEY, + SMALLMONEY, + TINYINT, +) # default dialect base.dialect = dialect = pyodbc.dialect __all__ = ( - 'CHAR', 'VARCHAR', 'TIME', 'NCHAR', 'NVARCHAR', - 'TEXT', 'DATE', 'DATETIME', 'FLOAT', 'NUMERIC', - 'BIGINT', 'INT', 'INTEGER', 'SMALLINT', 'BINARY', - 'VARBINARY', 'UNITEXT', 'UNICHAR', 'UNIVARCHAR', - 'IMAGE', 'BIT', 'MONEY', 'SMALLMONEY', 'TINYINT', - 'dialect' + "CHAR", + "VARCHAR", + "TIME", + "NCHAR", + "NVARCHAR", + "TEXT", + "DATE", + "DATETIME", + "FLOAT", + "NUMERIC", + "BIGINT", + "INT", + "INTEGER", + "SMALLINT", + "BINARY", + "VARBINARY", + "UNITEXT", + "UNICHAR", + "UNIVARCHAR", + "IMAGE", + "BIT", + "MONEY", + "SMALLMONEY", + "TINYINT", + "dialect", ) diff --git a/lib/sqlalchemy/dialects/sybase/base.py b/lib/sqlalchemy/dialects/sybase/base.py index 7dd973573..1214a9279 100644 --- a/lib/sqlalchemy/dialects/sybase/base.py +++ b/lib/sqlalchemy/dialects/sybase/base.py @@ -31,70 +31,257 @@ from sqlalchemy.sql import operators as sql_operators from sqlalchemy import schema as sa_schema from sqlalchemy import util, sql, exc -from sqlalchemy.types import CHAR, VARCHAR, TIME, NCHAR, NVARCHAR,\ - TEXT, DATE, DATETIME, FLOAT, NUMERIC,\ - BIGINT, INT, INTEGER, SMALLINT, BINARY,\ - VARBINARY, DECIMAL, TIMESTAMP, Unicode,\ - UnicodeText, REAL - -RESERVED_WORDS = set([ - "add", "all", "alter", "and", - "any", "as", "asc", "backup", - "begin", "between", "bigint", "binary", - "bit", "bottom", "break", "by", - "call", "capability", "cascade", "case", - "cast", "char", "char_convert", "character", - "check", "checkpoint", "close", "comment", - "commit", "connect", "constraint", "contains", - "continue", "convert", "create", "cross", - "cube", "current", "current_timestamp", "current_user", - "cursor", "date", "dbspace", "deallocate", - "dec", "decimal", "declare", "default", - "delete", "deleting", "desc", "distinct", - "do", "double", "drop", "dynamic", - "else", "elseif", "encrypted", "end", - "endif", "escape", "except", "exception", - "exec", "execute", "existing", "exists", - "externlogin", "fetch", "first", "float", - "for", "force", "foreign", "forward", - "from", "full", "goto", "grant", - "group", "having", "holdlock", "identified", - "if", "in", "index", "index_lparen", - "inner", "inout", "insensitive", "insert", - "inserting", "install", "instead", "int", - "integer", "integrated", "intersect", "into", - "iq", "is", "isolation", "join", - "key", "lateral", "left", "like", - "lock", "login", "long", "match", - "membership", "message", "mode", "modify", - "natural", "new", "no", "noholdlock", - "not", "notify", "null", "numeric", - "of", "off", "on", "open", - "option", "options", "or", "order", - "others", "out", "outer", "over", - "passthrough", "precision", "prepare", "primary", - "print", "privileges", "proc", "procedure", - "publication", "raiserror", "readtext", "real", - "reference", "references", "release", "remote", - "remove", "rename", "reorganize", "resource", - "restore", "restrict", "return", "revoke", - "right", "rollback", "rollup", "save", - "savepoint", "scroll", "select", "sensitive", - "session", "set", "setuser", "share", - "smallint", "some", "sqlcode", "sqlstate", - "start", "stop", "subtrans", "subtransaction", - "synchronize", "syntax_error", "table", "temporary", - "then", "time", "timestamp", "tinyint", - "to", "top", "tran", "trigger", - "truncate", "tsequal", "unbounded", "union", - "unique", "unknown", "unsigned", "update", - "updating", "user", "using", "validate", - "values", "varbinary", "varchar", "variable", - "varying", "view", "wait", "waitfor", - "when", "where", "while", "window", - "with", "with_cube", "with_lparen", "with_rollup", - "within", "work", "writetext", -]) +from sqlalchemy.types import ( + CHAR, + VARCHAR, + TIME, + NCHAR, + NVARCHAR, + TEXT, + DATE, + DATETIME, + FLOAT, + NUMERIC, + BIGINT, + INT, + INTEGER, + SMALLINT, + BINARY, + VARBINARY, + DECIMAL, + TIMESTAMP, + Unicode, + UnicodeText, + REAL, +) + +RESERVED_WORDS = set( + [ + "add", + "all", + "alter", + "and", + "any", + "as", + "asc", + "backup", + "begin", + "between", + "bigint", + "binary", + "bit", + "bottom", + "break", + "by", + "call", + "capability", + "cascade", + "case", + "cast", + "char", + "char_convert", + "character", + "check", + "checkpoint", + "close", + "comment", + "commit", + "connect", + "constraint", + "contains", + "continue", + "convert", + "create", + "cross", + "cube", + "current", + "current_timestamp", + "current_user", + "cursor", + "date", + "dbspace", + "deallocate", + "dec", + "decimal", + "declare", + "default", + "delete", + "deleting", + "desc", + "distinct", + "do", + "double", + "drop", + "dynamic", + "else", + "elseif", + "encrypted", + "end", + "endif", + "escape", + "except", + "exception", + "exec", + "execute", + "existing", + "exists", + "externlogin", + "fetch", + "first", + "float", + "for", + "force", + "foreign", + "forward", + "from", + "full", + "goto", + "grant", + "group", + "having", + "holdlock", + "identified", + "if", + "in", + "index", + "index_lparen", + "inner", + "inout", + "insensitive", + "insert", + "inserting", + "install", + "instead", + "int", + "integer", + "integrated", + "intersect", + "into", + "iq", + "is", + "isolation", + "join", + "key", + "lateral", + "left", + "like", + "lock", + "login", + "long", + "match", + "membership", + "message", + "mode", + "modify", + "natural", + "new", + "no", + "noholdlock", + "not", + "notify", + "null", + "numeric", + "of", + "off", + "on", + "open", + "option", + "options", + "or", + "order", + "others", + "out", + "outer", + "over", + "passthrough", + "precision", + "prepare", + "primary", + "print", + "privileges", + "proc", + "procedure", + "publication", + "raiserror", + "readtext", + "real", + "reference", + "references", + "release", + "remote", + "remove", + "rename", + "reorganize", + "resource", + "restore", + "restrict", + "return", + "revoke", + "right", + "rollback", + "rollup", + "save", + "savepoint", + "scroll", + "select", + "sensitive", + "session", + "set", + "setuser", + "share", + "smallint", + "some", + "sqlcode", + "sqlstate", + "start", + "stop", + "subtrans", + "subtransaction", + "synchronize", + "syntax_error", + "table", + "temporary", + "then", + "time", + "timestamp", + "tinyint", + "to", + "top", + "tran", + "trigger", + "truncate", + "tsequal", + "unbounded", + "union", + "unique", + "unknown", + "unsigned", + "update", + "updating", + "user", + "using", + "validate", + "values", + "varbinary", + "varchar", + "variable", + "varying", + "view", + "wait", + "waitfor", + "when", + "where", + "while", + "window", + "with", + "with_cube", + "with_lparen", + "with_rollup", + "within", + "work", + "writetext", + ] +) class _SybaseUnitypeMixin(object): @@ -106,27 +293,28 @@ class _SybaseUnitypeMixin(object): return str(value) # decode("ucs-2") else: return None + return process class UNICHAR(_SybaseUnitypeMixin, sqltypes.Unicode): - __visit_name__ = 'UNICHAR' + __visit_name__ = "UNICHAR" class UNIVARCHAR(_SybaseUnitypeMixin, sqltypes.Unicode): - __visit_name__ = 'UNIVARCHAR' + __visit_name__ = "UNIVARCHAR" class UNITEXT(_SybaseUnitypeMixin, sqltypes.UnicodeText): - __visit_name__ = 'UNITEXT' + __visit_name__ = "UNITEXT" class TINYINT(sqltypes.Integer): - __visit_name__ = 'TINYINT' + __visit_name__ = "TINYINT" class BIT(sqltypes.TypeEngine): - __visit_name__ = 'BIT' + __visit_name__ = "BIT" class MONEY(sqltypes.TypeEngine): @@ -142,7 +330,7 @@ class UNIQUEIDENTIFIER(sqltypes.TypeEngine): class IMAGE(sqltypes.LargeBinary): - __visit_name__ = 'IMAGE' + __visit_name__ = "IMAGE" class SybaseTypeCompiler(compiler.GenericTypeCompiler): @@ -182,67 +370,66 @@ class SybaseTypeCompiler(compiler.GenericTypeCompiler): def visit_UNIQUEIDENTIFIER(self, type_, **kw): return "UNIQUEIDENTIFIER" -ischema_names = { - 'bigint': BIGINT, - 'int': INTEGER, - 'integer': INTEGER, - 'smallint': SMALLINT, - 'tinyint': TINYINT, - 'unsigned bigint': BIGINT, # TODO: unsigned flags - 'unsigned int': INTEGER, # TODO: unsigned flags - 'unsigned smallint': SMALLINT, # TODO: unsigned flags - 'numeric': NUMERIC, - 'decimal': DECIMAL, - 'dec': DECIMAL, - 'float': FLOAT, - 'double': NUMERIC, # TODO - 'double precision': NUMERIC, # TODO - 'real': REAL, - 'smallmoney': SMALLMONEY, - 'money': MONEY, - 'smalldatetime': DATETIME, - 'datetime': DATETIME, - 'date': DATE, - 'time': TIME, - 'char': CHAR, - 'character': CHAR, - 'varchar': VARCHAR, - 'character varying': VARCHAR, - 'char varying': VARCHAR, - 'unichar': UNICHAR, - 'unicode character': UNIVARCHAR, - 'nchar': NCHAR, - 'national char': NCHAR, - 'national character': NCHAR, - 'nvarchar': NVARCHAR, - 'nchar varying': NVARCHAR, - 'national char varying': NVARCHAR, - 'national character varying': NVARCHAR, - 'text': TEXT, - 'unitext': UNITEXT, - 'binary': BINARY, - 'varbinary': VARBINARY, - 'image': IMAGE, - 'bit': BIT, +ischema_names = { + "bigint": BIGINT, + "int": INTEGER, + "integer": INTEGER, + "smallint": SMALLINT, + "tinyint": TINYINT, + "unsigned bigint": BIGINT, # TODO: unsigned flags + "unsigned int": INTEGER, # TODO: unsigned flags + "unsigned smallint": SMALLINT, # TODO: unsigned flags + "numeric": NUMERIC, + "decimal": DECIMAL, + "dec": DECIMAL, + "float": FLOAT, + "double": NUMERIC, # TODO + "double precision": NUMERIC, # TODO + "real": REAL, + "smallmoney": SMALLMONEY, + "money": MONEY, + "smalldatetime": DATETIME, + "datetime": DATETIME, + "date": DATE, + "time": TIME, + "char": CHAR, + "character": CHAR, + "varchar": VARCHAR, + "character varying": VARCHAR, + "char varying": VARCHAR, + "unichar": UNICHAR, + "unicode character": UNIVARCHAR, + "nchar": NCHAR, + "national char": NCHAR, + "national character": NCHAR, + "nvarchar": NVARCHAR, + "nchar varying": NVARCHAR, + "national char varying": NVARCHAR, + "national character varying": NVARCHAR, + "text": TEXT, + "unitext": UNITEXT, + "binary": BINARY, + "varbinary": VARBINARY, + "image": IMAGE, + "bit": BIT, # not in documentation for ASE 15.7 - 'long varchar': TEXT, # TODO - 'timestamp': TIMESTAMP, - 'uniqueidentifier': UNIQUEIDENTIFIER, - + "long varchar": TEXT, # TODO + "timestamp": TIMESTAMP, + "uniqueidentifier": UNIQUEIDENTIFIER, } class SybaseInspector(reflection.Inspector): - def __init__(self, conn): reflection.Inspector.__init__(self, conn) def get_table_id(self, table_name, schema=None): """Return the table id from `table_name` and `schema`.""" - return self.dialect.get_table_id(self.bind, table_name, schema, - info_cache=self.info_cache) + return self.dialect.get_table_id( + self.bind, table_name, schema, info_cache=self.info_cache + ) class SybaseExecutionContext(default.DefaultExecutionContext): @@ -267,15 +454,17 @@ class SybaseExecutionContext(default.DefaultExecutionContext): insert_has_sequence = seq_column is not None if insert_has_sequence: - self._enable_identity_insert = \ + self._enable_identity_insert = ( seq_column.key in self.compiled_parameters[0] + ) else: self._enable_identity_insert = False if self._enable_identity_insert: self.cursor.execute( - "SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(tbl)) + "SET IDENTITY_INSERT %s ON" + % self.dialect.identifier_preparer.format_table(tbl) + ) if self.isddl: # TODO: to enhance this, we can detect "ddl in tran" on the @@ -284,14 +473,16 @@ class SybaseExecutionContext(default.DefaultExecutionContext): if not self.should_autocommit: raise exc.InvalidRequestError( "The Sybase dialect only supports " - "DDL in 'autocommit' mode at this time.") + "DDL in 'autocommit' mode at this time." + ) self.root_connection.engine.logger.info( - "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')") + "AUTOCOMMIT (Assuming no Sybase 'ddl in tran')" + ) self.set_ddl_autocommit( - self.root_connection.connection.connection, - True) + self.root_connection.connection.connection, True + ) def post_exec(self): if self.isddl: @@ -299,9 +490,10 @@ class SybaseExecutionContext(default.DefaultExecutionContext): if self._enable_identity_insert: self.cursor.execute( - "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer. - format_table(self.compiled.statement.table) + "SET IDENTITY_INSERT %s OFF" + % self.dialect.identifier_preparer.format_table( + self.compiled.statement.table + ) ) def get_lastrowid(self): @@ -317,11 +509,8 @@ class SybaseSQLCompiler(compiler.SQLCompiler): extract_map = util.update_copy( compiler.SQLCompiler.extract_map, - { - 'doy': 'dayofyear', - 'dow': 'weekday', - 'milliseconds': 'millisecond' - }) + {"doy": "dayofyear", "dow": "weekday", "milliseconds": "millisecond"}, + ) def get_select_precolumns(self, select, **kw): s = select._distinct and "DISTINCT " or "" @@ -330,9 +519,9 @@ class SybaseSQLCompiler(compiler.SQLCompiler): limit = select._limit if limit: # if select._limit == 1: - # s += "FIRST " + # s += "FIRST " # else: - # s += "TOP %s " % (select._limit,) + # s += "TOP %s " % (select._limit,) s += "TOP %s " % (limit,) offset = select._offset if offset: @@ -348,8 +537,7 @@ class SybaseSQLCompiler(compiler.SQLCompiler): def visit_extract(self, extract, **kw): field = self.extract_map.get(extract.field, extract.field) - return 'DATEPART("%s", %s)' % ( - field, self.process(extract.expr, **kw)) + return 'DATEPART("%s", %s)' % (field, self.process(extract.expr, **kw)) def visit_now_func(self, fn, **kw): return "GETDATE()" @@ -357,10 +545,10 @@ class SybaseSQLCompiler(compiler.SQLCompiler): def for_update_clause(self, select): # "FOR UPDATE" is only allowed on "DECLARE CURSOR" # which SQLAlchemy doesn't use - return '' + return "" def order_by_clause(self, select, **kw): - kw['literal_binds'] = True + kw["literal_binds"] = True order_by = self.process(select._order_by_clause, **kw) # SybaseSQL only allows ORDER BY in subqueries if there is a LIMIT @@ -369,8 +557,7 @@ class SybaseSQLCompiler(compiler.SQLCompiler): else: return "" - def delete_table_clause(self, delete_stmt, from_table, - extra_froms): + def delete_table_clause(self, delete_stmt, from_table, extra_froms): """If we have extra froms make sure we render any alias as hint.""" ashint = False if extra_froms: @@ -379,34 +566,41 @@ class SybaseSQLCompiler(compiler.SQLCompiler): self, asfrom=True, iscrud=True, ashint=ashint ) - def delete_extra_from_clause(self, delete_stmt, from_table, - extra_froms, from_hints, **kw): + def delete_extra_from_clause( + self, delete_stmt, from_table, extra_froms, from_hints, **kw + ): """Render the DELETE .. FROM clause specific to Sybase.""" - return "FROM " + ', '.join( - t._compiler_dispatch(self, asfrom=True, - fromhints=from_hints, **kw) - for t in [from_table] + extra_froms) + return "FROM " + ", ".join( + t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw) + for t in [from_table] + extra_froms + ) class SybaseDDLCompiler(compiler.DDLCompiler): def get_column_specification(self, column, **kwargs): - colspec = self.preparer.format_column(column) + " " + \ - self.dialect.type_compiler.process( - column.type, type_expression=column) + colspec = ( + self.preparer.format_column(column) + + " " + + self.dialect.type_compiler.process( + column.type, type_expression=column + ) + ) if column.table is None: raise exc.CompileError( "The Sybase dialect requires Table-bound " - "columns in order to generate DDL") + "columns in order to generate DDL" + ) seq_col = column.table._autoincrement_column # install a IDENTITY Sequence if we have an implicit IDENTITY column if seq_col is column: - sequence = isinstance(column.default, sa_schema.Sequence) \ + sequence = ( + isinstance(column.default, sa_schema.Sequence) and column.default + ) if sequence: - start, increment = sequence.start or 1, \ - sequence.increment or 1 + start, increment = sequence.start or 1, sequence.increment or 1 else: start, increment = 1, 1 if (start, increment) == (1, 1): @@ -431,8 +625,7 @@ class SybaseDDLCompiler(compiler.DDLCompiler): index = drop.element return "\nDROP INDEX %s.%s" % ( self.preparer.quote_identifier(index.table.name), - self._prepared_index_name(drop.element, - include_schema=False) + self._prepared_index_name(drop.element, include_schema=False), ) @@ -441,7 +634,7 @@ class SybaseIdentifierPreparer(compiler.IdentifierPreparer): class SybaseDialect(default.DefaultDialect): - name = 'sybase' + name = "sybase" supports_unicode_statements = False supports_sane_rowcount = False supports_sane_multi_rowcount = False @@ -463,14 +656,18 @@ class SybaseDialect(default.DefaultDialect): def _get_default_schema_name(self, connection): return connection.scalar( - text("SELECT user_name() as user_name", - typemap={'user_name': Unicode}) + text( + "SELECT user_name() as user_name", + typemap={"user_name": Unicode}, + ) ) def initialize(self, connection): super(SybaseDialect, self).initialize(connection) - if self.server_version_info is not None and\ - self.server_version_info < (15, ): + if ( + self.server_version_info is not None + and self.server_version_info < (15,) + ): self.max_identifier_length = 30 else: self.max_identifier_length = 255 @@ -488,22 +685,24 @@ class SybaseDialect(default.DefaultDialect): if schema is None: schema = self.default_schema_name - TABLEID_SQL = text(""" + TABLEID_SQL = text( + """ SELECT o.id AS id FROM sysobjects o JOIN sysusers u ON o.uid=u.uid WHERE u.name = :schema_name AND o.name = :table_name AND o.type in ('U', 'V') - """) + """ + ) if util.py2k: if isinstance(schema, unicode): schema = schema.encode("ascii") if isinstance(table_name, unicode): table_name = table_name.encode("ascii") - result = connection.execute(TABLEID_SQL, - schema_name=schema, - table_name=table_name) + result = connection.execute( + TABLEID_SQL, schema_name=schema, table_name=table_name + ) table_id = result.scalar() if table_id is None: raise exc.NoSuchTableError(table_name) @@ -511,10 +710,12 @@ class SybaseDialect(default.DefaultDialect): @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - table_id = self.get_table_id(connection, table_name, schema, - info_cache=kw.get("info_cache")) + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) - COLUMN_SQL = text(""" + COLUMN_SQL = text( + """ SELECT col.name AS name, t.name AS type, (col.status & 8) AS nullable, @@ -528,23 +729,47 @@ class SybaseDialect(default.DefaultDialect): WHERE col.usertype = t.usertype AND col.id = :table_id ORDER BY col.colid - """) + """ + ) results = connection.execute(COLUMN_SQL, table_id=table_id) columns = [] - for (name, type_, nullable, autoincrement, default, precision, scale, - length) in results: - col_info = self._get_column_info(name, type_, bool(nullable), - bool(autoincrement), - default, precision, scale, - length) + for ( + name, + type_, + nullable, + autoincrement, + default, + precision, + scale, + length, + ) in results: + col_info = self._get_column_info( + name, + type_, + bool(nullable), + bool(autoincrement), + default, + precision, + scale, + length, + ) columns.append(col_info) return columns - def _get_column_info(self, name, type_, nullable, autoincrement, default, - precision, scale, length): + def _get_column_info( + self, + name, + type_, + nullable, + autoincrement, + default, + precision, + scale, + length, + ): coltype = self.ischema_names.get(type_, None) @@ -565,8 +790,9 @@ class SybaseDialect(default.DefaultDialect): # if is_array: # coltype = ARRAY(coltype) else: - util.warn("Did not recognize type '%s' of column '%s'" % - (type_, name)) + util.warn( + "Did not recognize type '%s' of column '%s'" % (type_, name) + ) coltype = sqltypes.NULLTYPE if default: @@ -575,15 +801,21 @@ class SybaseDialect(default.DefaultDialect): else: default = None - column_info = dict(name=name, type=coltype, nullable=nullable, - default=default, autoincrement=autoincrement) + column_info = dict( + name=name, + type=coltype, + nullable=nullable, + default=default, + autoincrement=autoincrement, + ) return column_info @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): - table_id = self.get_table_id(connection, table_name, schema, - info_cache=kw.get("info_cache")) + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) table_cache = {} column_cache = {} @@ -591,11 +823,13 @@ class SybaseDialect(default.DefaultDialect): table_cache[table_id] = {"name": table_name, "schema": schema} - COLUMN_SQL = text(""" + COLUMN_SQL = text( + """ SELECT c.colid AS id, c.name AS name FROM syscolumns c WHERE c.id = :table_id - """) + """ + ) results = connection.execute(COLUMN_SQL, table_id=table_id) columns = {} @@ -603,7 +837,8 @@ class SybaseDialect(default.DefaultDialect): columns[col["id"]] = col["name"] column_cache[table_id] = columns - REFCONSTRAINT_SQL = text(""" + REFCONSTRAINT_SQL = text( + """ SELECT o.name AS name, r.reftabid AS reftable_id, r.keycnt AS 'count', r.fokey1 AS fokey1, r.fokey2 AS fokey2, r.fokey3 AS fokey3, @@ -621,15 +856,19 @@ class SybaseDialect(default.DefaultDialect): r.refkey16 AS refkey16 FROM sysreferences r JOIN sysobjects o on r.tableid = o.id WHERE r.tableid = :table_id - """) + """ + ) referential_constraints = connection.execute( - REFCONSTRAINT_SQL, table_id=table_id).fetchall() + REFCONSTRAINT_SQL, table_id=table_id + ).fetchall() - REFTABLE_SQL = text(""" + REFTABLE_SQL = text( + """ SELECT o.name AS name, u.name AS 'schema' FROM sysobjects o JOIN sysusers u ON o.uid = u.uid WHERE o.id = :table_id - """) + """ + ) for r in referential_constraints: reftable_id = r["reftable_id"] @@ -639,8 +878,10 @@ class SybaseDialect(default.DefaultDialect): reftable = c.fetchone() c.close() table_info = {"name": reftable["name"], "schema": None} - if (schema is not None or - reftable["schema"] != self.default_schema_name): + if ( + schema is not None + or reftable["schema"] != self.default_schema_name + ): table_info["schema"] = reftable["schema"] table_cache[reftable_id] = table_info @@ -664,7 +905,7 @@ class SybaseDialect(default.DefaultDialect): "referred_schema": reftable["schema"], "referred_table": reftable["name"], "referred_columns": referred_columns, - "name": r["name"] + "name": r["name"], } foreign_keys.append(fk_info) @@ -673,10 +914,12 @@ class SybaseDialect(default.DefaultDialect): @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): - table_id = self.get_table_id(connection, table_name, schema, - info_cache=kw.get("info_cache")) + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) - INDEX_SQL = text(""" + INDEX_SQL = text( + """ SELECT object_name(i.id) AS table_name, i.keycnt AS 'count', i.name AS name, @@ -702,7 +945,8 @@ class SybaseDialect(default.DefaultDialect): AND o.id = :table_id AND (i.status & 2048) = 0 AND i.indid BETWEEN 1 AND 254 - """) + """ + ) results = connection.execute(INDEX_SQL, table_id=table_id) indexes = [] @@ -710,19 +954,23 @@ class SybaseDialect(default.DefaultDialect): column_names = [] for i in range(1, r["count"]): column_names.append(r["col_%i" % (i,)]) - index_info = {"name": r["name"], - "unique": bool(r["unique"]), - "column_names": column_names} + index_info = { + "name": r["name"], + "unique": bool(r["unique"]), + "column_names": column_names, + } indexes.append(index_info) return indexes @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - table_id = self.get_table_id(connection, table_name, schema, - info_cache=kw.get("info_cache")) + table_id = self.get_table_id( + connection, table_name, schema, info_cache=kw.get("info_cache") + ) - PK_SQL = text(""" + PK_SQL = text( + """ SELECT object_name(i.id) AS table_name, i.keycnt AS 'count', i.name AS name, @@ -747,7 +995,8 @@ class SybaseDialect(default.DefaultDialect): AND o.id = :table_id AND (i.status & 2048) = 2048 AND i.indid BETWEEN 1 AND 254 - """) + """ + ) results = connection.execute(PK_SQL, table_id=table_id) pks = results.fetchone() @@ -757,8 +1006,10 @@ class SybaseDialect(default.DefaultDialect): if pks: for i in range(1, pks["count"] + 1): constrained_columns.append(pks["pk_%i" % (i,)]) - return {"constrained_columns": constrained_columns, - "name": pks["name"]} + return { + "constrained_columns": constrained_columns, + "name": pks["name"], + } else: return {"constrained_columns": [], "name": None} @@ -776,12 +1027,14 @@ class SybaseDialect(default.DefaultDialect): if schema is None: schema = self.default_schema_name - TABLE_SQL = text(""" + TABLE_SQL = text( + """ SELECT o.name AS name FROM sysobjects o JOIN sysusers u ON o.uid = u.uid WHERE u.name = :schema_name AND o.type = 'U' - """) + """ + ) if util.py2k: if isinstance(schema, unicode): @@ -796,12 +1049,14 @@ class SybaseDialect(default.DefaultDialect): if schema is None: schema = self.default_schema_name - VIEW_DEF_SQL = text(""" + VIEW_DEF_SQL = text( + """ SELECT c.text FROM syscomments c JOIN sysobjects o ON c.id = o.id WHERE o.name = :view_name AND o.type = 'V' - """) + """ + ) if util.py2k: if isinstance(view_name, unicode): @@ -816,12 +1071,14 @@ class SybaseDialect(default.DefaultDialect): if schema is None: schema = self.default_schema_name - VIEW_SQL = text(""" + VIEW_SQL = text( + """ SELECT o.name AS name FROM sysobjects o JOIN sysusers u ON o.uid = u.uid WHERE u.name = :schema_name AND o.type = 'V' - """) + """ + ) if util.py2k: if isinstance(schema, unicode): diff --git a/lib/sqlalchemy/dialects/sybase/mxodbc.py b/lib/sqlalchemy/dialects/sybase/mxodbc.py index ddb6b7e21..eeceb359b 100644 --- a/lib/sqlalchemy/dialects/sybase/mxodbc.py +++ b/lib/sqlalchemy/dialects/sybase/mxodbc.py @@ -30,4 +30,5 @@ class SybaseExecutionContext_mxodbc(SybaseExecutionContext): class SybaseDialect_mxodbc(MxODBCConnector, SybaseDialect): execution_ctx_cls = SybaseExecutionContext_mxodbc + dialect = SybaseDialect_mxodbc diff --git a/lib/sqlalchemy/dialects/sybase/pyodbc.py b/lib/sqlalchemy/dialects/sybase/pyodbc.py index af6469dad..a4759428c 100644 --- a/lib/sqlalchemy/dialects/sybase/pyodbc.py +++ b/lib/sqlalchemy/dialects/sybase/pyodbc.py @@ -34,8 +34,10 @@ Currently *not* supported are:: """ -from sqlalchemy.dialects.sybase.base import SybaseDialect,\ - SybaseExecutionContext +from sqlalchemy.dialects.sybase.base import ( + SybaseDialect, + SybaseExecutionContext, +) from sqlalchemy.connectors.pyodbc import PyODBCConnector from sqlalchemy import types as sqltypes, processors import decimal @@ -51,12 +53,10 @@ class _SybNumeric_pyodbc(sqltypes.Numeric): """ def bind_processor(self, dialect): - super_process = super(_SybNumeric_pyodbc, self).\ - bind_processor(dialect) + super_process = super(_SybNumeric_pyodbc, self).bind_processor(dialect) def process(value): - if self.asdecimal and \ - isinstance(value, decimal.Decimal): + if self.asdecimal and isinstance(value, decimal.Decimal): if value.adjusted() < -6: return processors.to_float(value) @@ -65,6 +65,7 @@ class _SybNumeric_pyodbc(sqltypes.Numeric): return super_process(value) else: return value + return process @@ -79,8 +80,7 @@ class SybaseExecutionContext_pyodbc(SybaseExecutionContext): class SybaseDialect_pyodbc(PyODBCConnector, SybaseDialect): execution_ctx_cls = SybaseExecutionContext_pyodbc - colspecs = { - sqltypes.Numeric: _SybNumeric_pyodbc, - } + colspecs = {sqltypes.Numeric: _SybNumeric_pyodbc} + dialect = SybaseDialect_pyodbc diff --git a/lib/sqlalchemy/dialects/sybase/pysybase.py b/lib/sqlalchemy/dialects/sybase/pysybase.py index 2168d5572..09d2cf380 100644 --- a/lib/sqlalchemy/dialects/sybase/pysybase.py +++ b/lib/sqlalchemy/dialects/sybase/pysybase.py @@ -22,8 +22,11 @@ kind at this time. """ from sqlalchemy import types as sqltypes, processors -from sqlalchemy.dialects.sybase.base import SybaseDialect, \ - SybaseExecutionContext, SybaseSQLCompiler +from sqlalchemy.dialects.sybase.base import ( + SybaseDialect, + SybaseExecutionContext, + SybaseSQLCompiler, +) class _SybNumeric(sqltypes.Numeric): @@ -35,7 +38,6 @@ class _SybNumeric(sqltypes.Numeric): class SybaseExecutionContext_pysybase(SybaseExecutionContext): - def set_ddl_autocommit(self, dbapi_connection, value): if value: # call commit() on the Sybase connection directly, @@ -58,24 +60,22 @@ class SybaseSQLCompiler_pysybase(SybaseSQLCompiler): class SybaseDialect_pysybase(SybaseDialect): - driver = 'pysybase' + driver = "pysybase" execution_ctx_cls = SybaseExecutionContext_pysybase statement_compiler = SybaseSQLCompiler_pysybase - colspecs = { - sqltypes.Numeric: _SybNumeric, - sqltypes.Float: sqltypes.Float - } + colspecs = {sqltypes.Numeric: _SybNumeric, sqltypes.Float: sqltypes.Float} @classmethod def dbapi(cls): import Sybase + return Sybase def create_connect_args(self, url): - opts = url.translate_connect_args(username='user', password='passwd') + opts = url.translate_connect_args(username="user", password="passwd") - return ([opts.pop('host')], opts) + return ([opts.pop("host")], opts) def do_executemany(self, cursor, statement, parameters, context=None): # calling python-sybase executemany yields: @@ -90,13 +90,17 @@ class SybaseDialect_pysybase(SybaseDialect): return (vers / 1000, vers % 1000 / 100, vers % 100 / 10, vers % 10) def is_disconnect(self, e, connection, cursor): - if isinstance(e, (self.dbapi.OperationalError, - self.dbapi.ProgrammingError)): + if isinstance( + e, (self.dbapi.OperationalError, self.dbapi.ProgrammingError) + ): msg = str(e) - return ('Unable to complete network request to host' in msg or - 'Invalid connection state' in msg or - 'Invalid cursor state' in msg) + return ( + "Unable to complete network request to host" in msg + or "Invalid connection state" in msg + or "Invalid cursor state" in msg + ) else: return False + dialect = SybaseDialect_pysybase |
