summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2019-01-06 01:14:26 -0500
committermike bayer <mike_mp@zzzcomputing.com>2019-01-06 17:34:50 +0000
commit1e1a38e7801f410f244e4bbb44ec795ae152e04e (patch)
tree28e725c5c8188bd0cfd133d1e268dbca9b524978 /lib/sqlalchemy/sql
parent404e69426b05a82d905cbb3ad33adafccddb00dd (diff)
downloadsqlalchemy-1e1a38e7801f410f244e4bbb44ec795ae152e04e.tar.gz
Run black -l 79 against all source files
This is a straight reformat run using black as is, with no edits applied at all. The black run will format code consistently, however in some cases that are prevalent in SQLAlchemy code it produces too-long lines. The too-long lines will be resolved in the following commit that will resolve all remaining flake8 issues including shadowed builtins, long lines, import order, unused imports, duplicate imports, and docstring issues. Change-Id: I7eda77fed3d8e73df84b3651fd6cfcfe858d4dc9
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/__init__.py11
-rw-r--r--lib/sqlalchemy/sql/annotation.py16
-rw-r--r--lib/sqlalchemy/sql/base.py114
-rw-r--r--lib/sqlalchemy/sql/compiler.py2030
-rw-r--r--lib/sqlalchemy/sql/crud.py440
-rw-r--r--lib/sqlalchemy/sql/ddl.py306
-rw-r--r--lib/sqlalchemy/sql/default_comparator.py234
-rw-r--r--lib/sqlalchemy/sql/dml.py194
-rw-r--r--lib/sqlalchemy/sql/elements.py800
-rw-r--r--lib/sqlalchemy/sql/expression.py205
-rw-r--r--lib/sqlalchemy/sql/functions.py139
-rw-r--r--lib/sqlalchemy/sql/naming.py47
-rw-r--r--lib/sqlalchemy/sql/operators.py109
-rw-r--r--lib/sqlalchemy/sql/schema.py1129
-rw-r--r--lib/sqlalchemy/sql/selectable.py812
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py649
-rw-r--r--lib/sqlalchemy/sql/type_api.py122
-rw-r--r--lib/sqlalchemy/sql/util.py327
-rw-r--r--lib/sqlalchemy/sql/visitors.py48
19 files changed, 4603 insertions, 3129 deletions
diff --git a/lib/sqlalchemy/sql/__init__.py b/lib/sqlalchemy/sql/__init__.py
index aa811388b..87e2fb6c3 100644
--- a/lib/sqlalchemy/sql/__init__.py
+++ b/lib/sqlalchemy/sql/__init__.py
@@ -72,7 +72,7 @@ from .expression import (
union,
union_all,
update,
- within_group
+ within_group,
)
from .visitors import ClauseVisitor
@@ -84,12 +84,16 @@ def __go(lcls):
import inspect as _inspect
- __all__ = sorted(name for name, obj in lcls.items()
- if not (name.startswith('_') or _inspect.ismodule(obj)))
+ __all__ = sorted(
+ name
+ for name, obj in lcls.items()
+ if not (name.startswith("_") or _inspect.ismodule(obj))
+ )
from .annotation import _prepare_annotations, Annotated
from .elements import AnnotatedColumnElement, ClauseList
from .selectable import AnnotatedFromClause
+
_prepare_annotations(ColumnElement, AnnotatedColumnElement)
_prepare_annotations(FromClause, AnnotatedFromClause)
_prepare_annotations(ClauseList, Annotated)
@@ -98,4 +102,5 @@ def __go(lcls):
from . import naming
+
__go(locals())
diff --git a/lib/sqlalchemy/sql/annotation.py b/lib/sqlalchemy/sql/annotation.py
index c1d484d95..64cfa630e 100644
--- a/lib/sqlalchemy/sql/annotation.py
+++ b/lib/sqlalchemy/sql/annotation.py
@@ -76,8 +76,7 @@ class Annotated(object):
return self._with_annotations(_values)
def _compiler_dispatch(self, visitor, **kw):
- return self.__element.__class__._compiler_dispatch(
- self, visitor, **kw)
+ return self.__element.__class__._compiler_dispatch(self, visitor, **kw)
@property
def _constructor(self):
@@ -120,10 +119,13 @@ def _deep_annotate(element, annotations, exclude=None):
Elements within the exclude collection will be cloned but not annotated.
"""
+
def clone(elem):
- if exclude and \
- hasattr(elem, 'proxy_set') and \
- elem.proxy_set.intersection(exclude):
+ if (
+ exclude
+ and hasattr(elem, "proxy_set")
+ and elem.proxy_set.intersection(exclude)
+ ):
newelem = elem._clone()
elif annotations != elem._annotations:
newelem = elem._annotate(annotations)
@@ -191,8 +193,8 @@ def _new_annotation_type(cls, base_cls):
break
annotated_classes[cls] = anno_cls = type(
- "Annotated%s" % cls.__name__,
- (base_cls, cls), {})
+ "Annotated%s" % cls.__name__, (base_cls, cls), {}
+ )
globals()["Annotated%s" % cls.__name__] = anno_cls
return anno_cls
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 6b9b55753..45db215fe 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -15,8 +15,8 @@ import itertools
from .visitors import ClauseVisitor
import re
-PARSE_AUTOCOMMIT = util.symbol('PARSE_AUTOCOMMIT')
-NO_ARG = util.symbol('NO_ARG')
+PARSE_AUTOCOMMIT = util.symbol("PARSE_AUTOCOMMIT")
+NO_ARG = util.symbol("NO_ARG")
class Immutable(object):
@@ -77,7 +77,8 @@ class _DialectArgView(util.collections_abc.MutableMapping):
dialect, value_key = self._key(key)
except KeyError:
raise exc.ArgumentError(
- "Keys must be of the form <dialectname>_<argname>")
+ "Keys must be of the form <dialectname>_<argname>"
+ )
else:
self.obj.dialect_options[dialect][value_key] = value
@@ -86,15 +87,18 @@ class _DialectArgView(util.collections_abc.MutableMapping):
del self.obj.dialect_options[dialect][value_key]
def __len__(self):
- return sum(len(args._non_defaults) for args in
- self.obj.dialect_options.values())
+ return sum(
+ len(args._non_defaults)
+ for args in self.obj.dialect_options.values()
+ )
def __iter__(self):
return (
util.safe_kwarg("%s_%s" % (dialect_name, value_name))
for dialect_name in self.obj.dialect_options
- for value_name in
- self.obj.dialect_options[dialect_name]._non_defaults
+ for value_name in self.obj.dialect_options[
+ dialect_name
+ ]._non_defaults
)
@@ -187,8 +191,8 @@ class DialectKWArgs(object):
if construct_arg_dictionary is None:
raise exc.ArgumentError(
"Dialect '%s' does have keyword-argument "
- "validation and defaults enabled configured" %
- dialect_name)
+ "validation and defaults enabled configured" % dialect_name
+ )
if cls not in construct_arg_dictionary:
construct_arg_dictionary[cls] = {}
construct_arg_dictionary[cls][argument_name] = default
@@ -230,6 +234,7 @@ class DialectKWArgs(object):
if dialect_cls.construct_arguments is None:
return None
return dict(dialect_cls.construct_arguments)
+
_kw_registry = util.PopulateDict(_kw_reg_for_dialect)
def _kw_reg_for_dialect_cls(self, dialect_name):
@@ -274,11 +279,12 @@ class DialectKWArgs(object):
return
for k in kwargs:
- m = re.match('^(.+?)_(.+)$', k)
+ m = re.match("^(.+?)_(.+)$", k)
if not m:
raise TypeError(
"Additional arguments should be "
- "named <dialectname>_<argument>, got '%s'" % k)
+ "named <dialectname>_<argument>, got '%s'" % k
+ )
dialect_name, arg_name = m.group(1, 2)
try:
@@ -286,20 +292,22 @@ class DialectKWArgs(object):
except exc.NoSuchModuleError:
util.warn(
"Can't validate argument %r; can't "
- "locate any SQLAlchemy dialect named %r" %
- (k, dialect_name))
+ "locate any SQLAlchemy dialect named %r"
+ % (k, dialect_name)
+ )
self.dialect_options[dialect_name] = d = _DialectArgDict()
d._defaults.update({"*": None})
d._non_defaults[arg_name] = kwargs[k]
else:
- if "*" not in construct_arg_dictionary and \
- arg_name not in construct_arg_dictionary:
+ if (
+ "*" not in construct_arg_dictionary
+ and arg_name not in construct_arg_dictionary
+ ):
raise exc.ArgumentError(
"Argument %r is not accepted by "
- "dialect %r on behalf of %r" % (
- k,
- dialect_name, self.__class__
- ))
+ "dialect %r on behalf of %r"
+ % (k, dialect_name, self.__class__)
+ )
else:
construct_arg_dictionary[arg_name] = kwargs[k]
@@ -359,14 +367,14 @@ class Executable(Generative):
:meth:`.Query.execution_options()`
"""
- if 'isolation_level' in kw:
+ if "isolation_level" in kw:
raise exc.ArgumentError(
"'isolation_level' execution option may only be specified "
"on Connection.execution_options(), or "
"per-engine using the isolation_level "
"argument to create_engine()."
)
- if 'compiled_cache' in kw:
+ if "compiled_cache" in kw:
raise exc.ArgumentError(
"'compiled_cache' execution option may only be specified "
"on Connection.execution_options(), not per statement."
@@ -377,10 +385,12 @@ class Executable(Generative):
"""Compile and execute this :class:`.Executable`."""
e = self.bind
if e is None:
- label = getattr(self, 'description', self.__class__.__name__)
- msg = ('This %s is not directly bound to a Connection or Engine. '
- 'Use the .execute() method of a Connection or Engine '
- 'to execute this construct.' % label)
+ label = getattr(self, "description", self.__class__.__name__)
+ msg = (
+ "This %s is not directly bound to a Connection or Engine. "
+ "Use the .execute() method of a Connection or Engine "
+ "to execute this construct." % label
+ )
raise exc.UnboundExecutionError(msg)
return e._execute_clauseelement(self, multiparams, params)
@@ -434,7 +444,7 @@ class SchemaEventTarget(object):
class SchemaVisitor(ClauseVisitor):
"""Define the visiting for ``SchemaItem`` objects."""
- __traverse_options__ = {'schema_visitor': True}
+ __traverse_options__ = {"schema_visitor": True}
class ColumnCollection(util.OrderedProperties):
@@ -446,11 +456,11 @@ class ColumnCollection(util.OrderedProperties):
"""
- __slots__ = '_all_columns'
+ __slots__ = "_all_columns"
def __init__(self, *columns):
super(ColumnCollection, self).__init__()
- object.__setattr__(self, '_all_columns', [])
+ object.__setattr__(self, "_all_columns", [])
for c in columns:
self.add(c)
@@ -485,8 +495,9 @@ class ColumnCollection(util.OrderedProperties):
self._data[column.key] = column
if remove_col is not None:
- self._all_columns[:] = [column if c is remove_col
- else c for c in self._all_columns]
+ self._all_columns[:] = [
+ column if c is remove_col else c for c in self._all_columns
+ ]
else:
self._all_columns.append(column)
@@ -499,7 +510,8 @@ class ColumnCollection(util.OrderedProperties):
"""
if not column.key:
raise exc.ArgumentError(
- "Can't add unnamed column to column collection")
+ "Can't add unnamed column to column collection"
+ )
self[column.key] = column
def __delitem__(self, key):
@@ -521,10 +533,12 @@ class ColumnCollection(util.OrderedProperties):
return
if not existing.shares_lineage(value):
- util.warn('Column %r on table %r being replaced by '
- '%r, which has the same key. Consider '
- 'use_labels for select() statements.' %
- (key, getattr(existing, 'table', None), value))
+ util.warn(
+ "Column %r on table %r being replaced by "
+ "%r, which has the same key. Consider "
+ "use_labels for select() statements."
+ % (key, getattr(existing, "table", None), value)
+ )
# pop out memoized proxy_set as this
# operation may very well be occurring
@@ -540,13 +554,15 @@ class ColumnCollection(util.OrderedProperties):
def remove(self, column):
del self._data[column.key]
self._all_columns[:] = [
- c for c in self._all_columns if c is not column]
+ c for c in self._all_columns if c is not column
+ ]
def update(self, iter):
cols = list(iter)
all_col_set = set(self._all_columns)
self._all_columns.extend(
- c for label, c in cols if c not in all_col_set)
+ c for label, c in cols if c not in all_col_set
+ )
self._data.update((label, c) for label, c in cols)
def extend(self, iter):
@@ -572,12 +588,11 @@ class ColumnCollection(util.OrderedProperties):
return util.OrderedProperties.__contains__(self, other)
def __getstate__(self):
- return {'_data': self._data,
- '_all_columns': self._all_columns}
+ return {"_data": self._data, "_all_columns": self._all_columns}
def __setstate__(self, state):
- object.__setattr__(self, '_data', state['_data'])
- object.__setattr__(self, '_all_columns', state['_all_columns'])
+ object.__setattr__(self, "_data", state["_data"])
+ object.__setattr__(self, "_all_columns", state["_all_columns"])
def contains_column(self, col):
return col in set(self._all_columns)
@@ -589,7 +604,7 @@ class ColumnCollection(util.OrderedProperties):
class ImmutableColumnCollection(util.ImmutableProperties, ColumnCollection):
def __init__(self, data, all_columns):
util.ImmutableProperties.__init__(self, data)
- object.__setattr__(self, '_all_columns', all_columns)
+ object.__setattr__(self, "_all_columns", all_columns)
extend = remove = util.ImmutableProperties._immutable
@@ -622,15 +637,18 @@ def _bind_or_error(schemaitem, msg=None):
bind = schemaitem.bind
if not bind:
name = schemaitem.__class__.__name__
- label = getattr(schemaitem, 'fullname',
- getattr(schemaitem, 'name', None))
+ label = getattr(
+ schemaitem, "fullname", getattr(schemaitem, "name", None)
+ )
if label:
- item = '%s object %r' % (name, label)
+ item = "%s object %r" % (name, label)
else:
- item = '%s object' % name
+ item = "%s object" % name
if msg is None:
- msg = "%s is not bound to an Engine or Connection. "\
- "Execution can not proceed without a database to execute "\
+ msg = (
+ "%s is not bound to an Engine or Connection. "
+ "Execution can not proceed without a database to execute "
"against." % item
+ )
raise exc.UnboundExecutionError(msg)
return bind
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 80ed707ed..f641d0a84 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -25,133 +25,218 @@ To generate user-defined SQL strings, see
import contextlib
import re
-from . import schema, sqltypes, operators, functions, visitors, \
- elements, selectable, crud
+from . import (
+ schema,
+ sqltypes,
+ operators,
+ functions,
+ visitors,
+ elements,
+ selectable,
+ crud,
+)
from .. import util, exc
import itertools
-RESERVED_WORDS = set([
- 'all', 'analyse', 'analyze', 'and', 'any', 'array',
- 'as', 'asc', 'asymmetric', 'authorization', 'between',
- 'binary', 'both', 'case', 'cast', 'check', 'collate',
- 'column', 'constraint', 'create', 'cross', 'current_date',
- 'current_role', 'current_time', 'current_timestamp',
- 'current_user', 'default', 'deferrable', 'desc',
- 'distinct', 'do', 'else', 'end', 'except', 'false',
- 'for', 'foreign', 'freeze', 'from', 'full', 'grant',
- 'group', 'having', 'ilike', 'in', 'initially', 'inner',
- 'intersect', 'into', 'is', 'isnull', 'join', 'leading',
- 'left', 'like', 'limit', 'localtime', 'localtimestamp',
- 'natural', 'new', 'not', 'notnull', 'null', 'off', 'offset',
- 'old', 'on', 'only', 'or', 'order', 'outer', 'overlaps',
- 'placing', 'primary', 'references', 'right', 'select',
- 'session_user', 'set', 'similar', 'some', 'symmetric', 'table',
- 'then', 'to', 'trailing', 'true', 'union', 'unique', 'user',
- 'using', 'verbose', 'when', 'where'])
-
-LEGAL_CHARACTERS = re.compile(r'^[A-Z0-9_$]+$', re.I)
-ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(['$'])
-
-BIND_PARAMS = re.compile(r'(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])', re.UNICODE)
-BIND_PARAMS_ESC = re.compile(r'\x5c(:[\w\$]*)(?![:\w\$])', re.UNICODE)
+RESERVED_WORDS = set(
+ [
+ "all",
+ "analyse",
+ "analyze",
+ "and",
+ "any",
+ "array",
+ "as",
+ "asc",
+ "asymmetric",
+ "authorization",
+ "between",
+ "binary",
+ "both",
+ "case",
+ "cast",
+ "check",
+ "collate",
+ "column",
+ "constraint",
+ "create",
+ "cross",
+ "current_date",
+ "current_role",
+ "current_time",
+ "current_timestamp",
+ "current_user",
+ "default",
+ "deferrable",
+ "desc",
+ "distinct",
+ "do",
+ "else",
+ "end",
+ "except",
+ "false",
+ "for",
+ "foreign",
+ "freeze",
+ "from",
+ "full",
+ "grant",
+ "group",
+ "having",
+ "ilike",
+ "in",
+ "initially",
+ "inner",
+ "intersect",
+ "into",
+ "is",
+ "isnull",
+ "join",
+ "leading",
+ "left",
+ "like",
+ "limit",
+ "localtime",
+ "localtimestamp",
+ "natural",
+ "new",
+ "not",
+ "notnull",
+ "null",
+ "off",
+ "offset",
+ "old",
+ "on",
+ "only",
+ "or",
+ "order",
+ "outer",
+ "overlaps",
+ "placing",
+ "primary",
+ "references",
+ "right",
+ "select",
+ "session_user",
+ "set",
+ "similar",
+ "some",
+ "symmetric",
+ "table",
+ "then",
+ "to",
+ "trailing",
+ "true",
+ "union",
+ "unique",
+ "user",
+ "using",
+ "verbose",
+ "when",
+ "where",
+ ]
+)
+
+LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+
+BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
+BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
BIND_TEMPLATES = {
- 'pyformat': "%%(%(name)s)s",
- 'qmark': "?",
- 'format': "%%s",
- 'numeric': ":[_POSITION]",
- 'named': ":%(name)s"
+ "pyformat": "%%(%(name)s)s",
+ "qmark": "?",
+ "format": "%%s",
+ "numeric": ":[_POSITION]",
+ "named": ":%(name)s",
}
OPERATORS = {
# binary
- operators.and_: ' AND ',
- operators.or_: ' OR ',
- operators.add: ' + ',
- operators.mul: ' * ',
- operators.sub: ' - ',
- operators.div: ' / ',
- operators.mod: ' % ',
- operators.truediv: ' / ',
- operators.neg: '-',
- operators.lt: ' < ',
- operators.le: ' <= ',
- operators.ne: ' != ',
- operators.gt: ' > ',
- operators.ge: ' >= ',
- operators.eq: ' = ',
- operators.is_distinct_from: ' IS DISTINCT FROM ',
- operators.isnot_distinct_from: ' IS NOT DISTINCT FROM ',
- operators.concat_op: ' || ',
- operators.match_op: ' MATCH ',
- operators.notmatch_op: ' NOT MATCH ',
- operators.in_op: ' IN ',
- operators.notin_op: ' NOT IN ',
- operators.comma_op: ', ',
- operators.from_: ' FROM ',
- operators.as_: ' AS ',
- operators.is_: ' IS ',
- operators.isnot: ' IS NOT ',
- operators.collate: ' COLLATE ',
-
+ operators.and_: " AND ",
+ operators.or_: " OR ",
+ operators.add: " + ",
+ operators.mul: " * ",
+ operators.sub: " - ",
+ operators.div: " / ",
+ operators.mod: " % ",
+ operators.truediv: " / ",
+ operators.neg: "-",
+ operators.lt: " < ",
+ operators.le: " <= ",
+ operators.ne: " != ",
+ operators.gt: " > ",
+ operators.ge: " >= ",
+ operators.eq: " = ",
+ operators.is_distinct_from: " IS DISTINCT FROM ",
+ operators.isnot_distinct_from: " IS NOT DISTINCT FROM ",
+ operators.concat_op: " || ",
+ operators.match_op: " MATCH ",
+ operators.notmatch_op: " NOT MATCH ",
+ operators.in_op: " IN ",
+ operators.notin_op: " NOT IN ",
+ operators.comma_op: ", ",
+ operators.from_: " FROM ",
+ operators.as_: " AS ",
+ operators.is_: " IS ",
+ operators.isnot: " IS NOT ",
+ operators.collate: " COLLATE ",
# unary
- operators.exists: 'EXISTS ',
- operators.distinct_op: 'DISTINCT ',
- operators.inv: 'NOT ',
- operators.any_op: 'ANY ',
- operators.all_op: 'ALL ',
-
+ operators.exists: "EXISTS ",
+ operators.distinct_op: "DISTINCT ",
+ operators.inv: "NOT ",
+ operators.any_op: "ANY ",
+ operators.all_op: "ALL ",
# modifiers
- operators.desc_op: ' DESC',
- operators.asc_op: ' ASC',
- operators.nullsfirst_op: ' NULLS FIRST',
- operators.nullslast_op: ' NULLS LAST',
-
+ operators.desc_op: " DESC",
+ operators.asc_op: " ASC",
+ operators.nullsfirst_op: " NULLS FIRST",
+ operators.nullslast_op: " NULLS LAST",
}
FUNCTIONS = {
- functions.coalesce: 'coalesce',
- functions.current_date: 'CURRENT_DATE',
- functions.current_time: 'CURRENT_TIME',
- functions.current_timestamp: 'CURRENT_TIMESTAMP',
- functions.current_user: 'CURRENT_USER',
- functions.localtime: 'LOCALTIME',
- functions.localtimestamp: 'LOCALTIMESTAMP',
- functions.random: 'random',
- functions.sysdate: 'sysdate',
- functions.session_user: 'SESSION_USER',
- functions.user: 'USER',
- functions.cube: 'CUBE',
- functions.rollup: 'ROLLUP',
- functions.grouping_sets: 'GROUPING SETS',
+ functions.coalesce: "coalesce",
+ functions.current_date: "CURRENT_DATE",
+ functions.current_time: "CURRENT_TIME",
+ functions.current_timestamp: "CURRENT_TIMESTAMP",
+ functions.current_user: "CURRENT_USER",
+ functions.localtime: "LOCALTIME",
+ functions.localtimestamp: "LOCALTIMESTAMP",
+ functions.random: "random",
+ functions.sysdate: "sysdate",
+ functions.session_user: "SESSION_USER",
+ functions.user: "USER",
+ functions.cube: "CUBE",
+ functions.rollup: "ROLLUP",
+ functions.grouping_sets: "GROUPING SETS",
}
EXTRACT_MAP = {
- 'month': 'month',
- 'day': 'day',
- 'year': 'year',
- 'second': 'second',
- 'hour': 'hour',
- 'doy': 'doy',
- 'minute': 'minute',
- 'quarter': 'quarter',
- 'dow': 'dow',
- 'week': 'week',
- 'epoch': 'epoch',
- 'milliseconds': 'milliseconds',
- 'microseconds': 'microseconds',
- 'timezone_hour': 'timezone_hour',
- 'timezone_minute': 'timezone_minute'
+ "month": "month",
+ "day": "day",
+ "year": "year",
+ "second": "second",
+ "hour": "hour",
+ "doy": "doy",
+ "minute": "minute",
+ "quarter": "quarter",
+ "dow": "dow",
+ "week": "week",
+ "epoch": "epoch",
+ "milliseconds": "milliseconds",
+ "microseconds": "microseconds",
+ "timezone_hour": "timezone_hour",
+ "timezone_minute": "timezone_minute",
}
COMPOUND_KEYWORDS = {
- selectable.CompoundSelect.UNION: 'UNION',
- selectable.CompoundSelect.UNION_ALL: 'UNION ALL',
- selectable.CompoundSelect.EXCEPT: 'EXCEPT',
- selectable.CompoundSelect.EXCEPT_ALL: 'EXCEPT ALL',
- selectable.CompoundSelect.INTERSECT: 'INTERSECT',
- selectable.CompoundSelect.INTERSECT_ALL: 'INTERSECT ALL'
+ selectable.CompoundSelect.UNION: "UNION",
+ selectable.CompoundSelect.UNION_ALL: "UNION ALL",
+ selectable.CompoundSelect.EXCEPT: "EXCEPT",
+ selectable.CompoundSelect.EXCEPT_ALL: "EXCEPT ALL",
+ selectable.CompoundSelect.INTERSECT: "INTERSECT",
+ selectable.CompoundSelect.INTERSECT_ALL: "INTERSECT ALL",
}
@@ -177,9 +262,14 @@ class Compiled(object):
sub-elements of the statement can modify these.
"""
- def __init__(self, dialect, statement, bind=None,
- schema_translate_map=None,
- compile_kwargs=util.immutabledict()):
+ def __init__(
+ self,
+ dialect,
+ statement,
+ bind=None,
+ schema_translate_map=None,
+ compile_kwargs=util.immutabledict(),
+ ):
"""Construct a new :class:`.Compiled` object.
:param dialect: :class:`.Dialect` to compile against.
@@ -209,7 +299,8 @@ class Compiled(object):
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
self.preparer = self.preparer._with_schema_translate(
- schema_translate_map)
+ schema_translate_map
+ )
if statement is not None:
self.statement = statement
@@ -218,8 +309,10 @@ class Compiled(object):
self.execution_options = statement._execution_options
self.string = self.process(self.statement, **compile_kwargs)
- @util.deprecated("0.7", ":class:`.Compiled` objects now compile "
- "within the constructor.")
+ @util.deprecated(
+ "0.7",
+ ":class:`.Compiled` objects now compile " "within the constructor.",
+ )
def compile(self):
"""Produce the internal string representation of this element.
"""
@@ -247,7 +340,7 @@ class Compiled(object):
def __str__(self):
"""Return the string text of the generated SQL or DDL."""
- return self.string or ''
+ return self.string or ""
def construct_params(self, params=None):
"""Return the bind params for this compiled object.
@@ -271,7 +364,9 @@ class Compiled(object):
if e is None:
raise exc.UnboundExecutionError(
"This Compiled object is not bound to any Engine "
- "or Connection.", code="2afi")
+ "or Connection.",
+ code="2afi",
+ )
return e._execute_compiled(self, multiparams, params)
def scalar(self, *multiparams, **params):
@@ -284,7 +379,7 @@ class Compiled(object):
class TypeCompiler(util.with_metaclass(util.EnsureKWArgType, object)):
"""Produces DDL specification for TypeEngine objects."""
- ensure_kwarg = r'visit_\w+'
+ ensure_kwarg = r"visit_\w+"
def __init__(self, dialect):
self.dialect = dialect
@@ -297,8 +392,8 @@ class _CompileLabel(visitors.Visitable):
"""lightweight label object which acts as an expression.Label."""
- __visit_name__ = 'label'
- __slots__ = 'element', 'name'
+ __visit_name__ = "label"
+ __slots__ = "element", "name"
def __init__(self, col, name, alt_names=()):
self.element = col
@@ -390,8 +485,9 @@ class SQLCompiler(Compiled):
insert_prefetch = update_prefetch = ()
- def __init__(self, dialect, statement, column_keys=None,
- inline=False, **kwargs):
+ def __init__(
+ self, dialect, statement, column_keys=None, inline=False, **kwargs
+ ):
"""Construct a new :class:`.SQLCompiler` object.
:param dialect: :class:`.Dialect` to be used
@@ -412,7 +508,7 @@ class SQLCompiler(Compiled):
# compile INSERT/UPDATE defaults/sequences inlined (no pre-
# execute)
- self.inline = inline or getattr(statement, 'inline', False)
+ self.inline = inline or getattr(statement, "inline", False)
# a dictionary of bind parameter keys to BindParameter
# instances.
@@ -440,8 +536,9 @@ class SQLCompiler(Compiled):
self.ctes = None
- self.label_length = dialect.label_length \
- or dialect.max_identifier_length
+ self.label_length = (
+ dialect.label_length or dialect.max_identifier_length
+ )
# a map which tracks "anonymous" identifiers that are created on
# the fly here
@@ -453,7 +550,7 @@ class SQLCompiler(Compiled):
Compiled.__init__(self, dialect, statement, **kwargs)
if (
- self.isinsert or self.isupdate or self.isdelete
+ self.isinsert or self.isupdate or self.isdelete
) and statement._returning:
self.returning = statement._returning
@@ -482,37 +579,43 @@ class SQLCompiler(Compiled):
def _nested_result(self):
"""special API to support the use case of 'nested result sets'"""
result_columns, ordered_columns = (
- self._result_columns, self._ordered_columns)
+ self._result_columns,
+ self._ordered_columns,
+ )
self._result_columns, self._ordered_columns = [], False
try:
if self.stack:
entry = self.stack[-1]
- entry['need_result_map_for_nested'] = True
+ entry["need_result_map_for_nested"] = True
else:
entry = None
yield self._result_columns, self._ordered_columns
finally:
if entry:
- entry.pop('need_result_map_for_nested')
+ entry.pop("need_result_map_for_nested")
self._result_columns, self._ordered_columns = (
- result_columns, ordered_columns)
+ result_columns,
+ ordered_columns,
+ )
def _apply_numbered_params(self):
poscount = itertools.count(1)
self.string = re.sub(
- r'\[_POSITION\]',
- lambda m: str(util.next(poscount)),
- self.string)
+ r"\[_POSITION\]", lambda m: str(util.next(poscount)), self.string
+ )
@util.memoized_property
def _bind_processors(self):
return dict(
- (key, value) for key, value in
- ((self.bind_names[bindparam],
- bindparam.type._cached_bind_processor(self.dialect)
- )
- for bindparam in self.bind_names)
+ (key, value)
+ for key, value in (
+ (
+ self.bind_names[bindparam],
+ bindparam.type._cached_bind_processor(self.dialect),
+ )
+ for bindparam in self.bind_names
+ )
if value is not None
)
@@ -539,12 +642,16 @@ class SQLCompiler(Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
- (bindparam.key, _group_number), code="cd3x")
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
- % bindparam.key, code="cd3x")
+ % bindparam.key,
+ code="cd3x",
+ )
elif bindparam.callable:
pd[name] = bindparam.effective_value
@@ -558,12 +665,16 @@ class SQLCompiler(Compiled):
if _group_number:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r, "
- "in parameter group %d" %
- (bindparam.key, _group_number), code="cd3x")
+ "in parameter group %d"
+ % (bindparam.key, _group_number),
+ code="cd3x",
+ )
else:
raise exc.InvalidRequestError(
"A value is required for bind parameter %r"
- % bindparam.key, code="cd3x")
+ % bindparam.key,
+ code="cd3x",
+ )
if bindparam.callable:
pd[self.bind_names[bindparam]] = bindparam.effective_value
@@ -595,9 +706,10 @@ class SQLCompiler(Compiled):
return "(" + grouping.element._compiler_dispatch(self, **kwargs) + ")"
def visit_label_reference(
- self, element, within_columns_clause=False, **kwargs):
+ self, element, within_columns_clause=False, **kwargs
+ ):
if self.stack and self.dialect.supports_simple_order_by_label:
- selectable = self.stack[-1]['selectable']
+ selectable = self.stack[-1]["selectable"]
with_cols, only_froms, only_cols = selectable._label_resolve_dict
if within_columns_clause:
@@ -611,25 +723,30 @@ class SQLCompiler(Compiled):
# to something else like a ColumnClause expression.
order_by_elem = element.element._order_by_label_element
- if order_by_elem is not None and order_by_elem.name in \
- resolve_dict and \
- order_by_elem.shares_lineage(
- resolve_dict[order_by_elem.name]):
- kwargs['render_label_as_label'] = \
- element.element._order_by_label_element
+ if (
+ order_by_elem is not None
+ and order_by_elem.name in resolve_dict
+ and order_by_elem.shares_lineage(
+ resolve_dict[order_by_elem.name]
+ )
+ ):
+ kwargs[
+ "render_label_as_label"
+ ] = element.element._order_by_label_element
return self.process(
- element.element, within_columns_clause=within_columns_clause,
- **kwargs)
+ element.element,
+ within_columns_clause=within_columns_clause,
+ **kwargs
+ )
def visit_textual_label_reference(
- self, element, within_columns_clause=False, **kwargs):
+ self, element, within_columns_clause=False, **kwargs
+ ):
if not self.stack:
# compiling the element outside of the context of a SELECT
- return self.process(
- element._text_clause
- )
+ return self.process(element._text_clause)
- selectable = self.stack[-1]['selectable']
+ selectable = self.stack[-1]["selectable"]
with_cols, only_froms, only_cols = selectable._label_resolve_dict
try:
if within_columns_clause:
@@ -640,26 +757,30 @@ class SQLCompiler(Compiled):
# treat it like text()
util.warn_limited(
"Can't resolve label reference %r; converting to text()",
- util.ellipses_string(element.element))
- return self.process(
- element._text_clause
+ util.ellipses_string(element.element),
)
+ return self.process(element._text_clause)
else:
- kwargs['render_label_as_label'] = col
+ kwargs["render_label_as_label"] = col
return self.process(
- col, within_columns_clause=within_columns_clause, **kwargs)
-
- def visit_label(self, label,
- add_to_result_map=None,
- within_label_clause=False,
- within_columns_clause=False,
- render_label_as_label=None,
- **kw):
+ col, within_columns_clause=within_columns_clause, **kwargs
+ )
+
+ def visit_label(
+ self,
+ label,
+ add_to_result_map=None,
+ within_label_clause=False,
+ within_columns_clause=False,
+ render_label_as_label=None,
+ **kw
+ ):
# only render labels within the columns clause
# or ORDER BY clause of a select. dialect-specific compilers
# can modify this behavior.
- render_label_with_as = (within_columns_clause and not
- within_label_clause)
+ render_label_with_as = (
+ within_columns_clause and not within_label_clause
+ )
render_label_only = render_label_as_label is label
if render_label_only or render_label_with_as:
@@ -673,27 +794,35 @@ class SQLCompiler(Compiled):
add_to_result_map(
labelname,
label.name,
- (label, labelname, ) + label._alt_names,
- label.type
+ (label, labelname) + label._alt_names,
+ label.type,
)
- return label.element._compiler_dispatch(
- self, within_columns_clause=True,
- within_label_clause=True, **kw) + \
- OPERATORS[operators.as_] + \
- self.preparer.format_label(label, labelname)
+ return (
+ label.element._compiler_dispatch(
+ self,
+ within_columns_clause=True,
+ within_label_clause=True,
+ **kw
+ )
+ + OPERATORS[operators.as_]
+ + self.preparer.format_label(label, labelname)
+ )
elif render_label_only:
return self.preparer.format_label(label, labelname)
else:
return label.element._compiler_dispatch(
- self, within_columns_clause=False, **kw)
+ self, within_columns_clause=False, **kw
+ )
def _fallback_column_name(self, column):
- raise exc.CompileError("Cannot compile Column object until "
- "its 'name' is assigned.")
+ raise exc.CompileError(
+ "Cannot compile Column object until " "its 'name' is assigned."
+ )
- def visit_column(self, column, add_to_result_map=None,
- include_table=True, **kwargs):
+ def visit_column(
+ self, column, add_to_result_map=None, include_table=True, **kwargs
+ ):
name = orig_name = column.name
if name is None:
name = self._fallback_column_name(column)
@@ -704,10 +833,7 @@ class SQLCompiler(Compiled):
if add_to_result_map is not None:
add_to_result_map(
- name,
- orig_name,
- (column, name, column.key),
- column.type
+ name, orig_name, (column, name, column.key), column.type
)
if is_literal:
@@ -721,17 +847,16 @@ class SQLCompiler(Compiled):
effective_schema = self.preparer.schema_for_object(table)
if effective_schema:
- schema_prefix = self.preparer.quote_schema(
- effective_schema) + '.'
+ schema_prefix = (
+ self.preparer.quote_schema(effective_schema) + "."
+ )
else:
- schema_prefix = ''
+ schema_prefix = ""
tablename = table.name
if isinstance(tablename, elements._truncated_label):
tablename = self._truncated_identifier("alias", tablename)
- return schema_prefix + \
- self.preparer.quote(tablename) + \
- "." + name
+ return schema_prefix + self.preparer.quote(tablename) + "." + name
def visit_collation(self, element, **kw):
return self.preparer.format_collation(element.collation)
@@ -743,17 +868,17 @@ class SQLCompiler(Compiled):
return index.name
def visit_typeclause(self, typeclause, **kw):
- kw['type_expression'] = typeclause
+ kw["type_expression"] = typeclause
return self.dialect.type_compiler.process(typeclause.type, **kw)
def post_process_text(self, text):
if self.preparer._double_percents:
- text = text.replace('%', '%%')
+ text = text.replace("%", "%%")
return text
def escape_literal_column(self, text):
if self.preparer._double_percents:
- text = text.replace('%', '%%')
+ text = text.replace("%", "%%")
return text
def visit_textclause(self, textclause, **kw):
@@ -771,30 +896,36 @@ class SQLCompiler(Compiled):
return BIND_PARAMS_ESC.sub(
lambda m: m.group(1),
BIND_PARAMS.sub(
- do_bindparam,
- self.post_process_text(textclause.text))
+ do_bindparam, self.post_process_text(textclause.text)
+ ),
)
- def visit_text_as_from(self, taf,
- compound_index=None,
- asfrom=False,
- parens=True, **kw):
+ def visit_text_as_from(
+ self, taf, compound_index=None, asfrom=False, parens=True, **kw
+ ):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = toplevel or \
- (
- compound_index == 0 and entry.get(
- 'need_result_map_for_compound', False)
- ) or entry.get('need_result_map_for_nested', False)
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
if populate_result_map:
- self._ordered_columns = \
- self._textual_ordered_columns = taf.positional
+ self._ordered_columns = (
+ self._textual_ordered_columns
+ ) = taf.positional
for c in taf.column_args:
- self.process(c, within_columns_clause=True,
- add_to_result_map=self._add_to_result_map)
+ self.process(
+ c,
+ within_columns_clause=True,
+ add_to_result_map=self._add_to_result_map,
+ )
text = self.process(taf.element, **kw)
if asfrom and parens:
@@ -802,17 +933,17 @@ class SQLCompiler(Compiled):
return text
def visit_null(self, expr, **kw):
- return 'NULL'
+ return "NULL"
def visit_true(self, expr, **kw):
if self.dialect.supports_native_boolean:
- return 'true'
+ return "true"
else:
return "1"
def visit_false(self, expr, **kw):
if self.dialect.supports_native_boolean:
- return 'false'
+ return "false"
else:
return "0"
@@ -823,25 +954,29 @@ class SQLCompiler(Compiled):
else:
sep = OPERATORS[clauselist.operator]
return sep.join(
- s for s in
- (
- c._compiler_dispatch(self, **kw)
- for c in clauselist.clauses)
- if s)
+ s
+ for s in (
+ c._compiler_dispatch(self, **kw) for c in clauselist.clauses
+ )
+ if s
+ )
def visit_case(self, clause, **kwargs):
x = "CASE "
if clause.value is not None:
x += clause.value._compiler_dispatch(self, **kwargs) + " "
for cond, result in clause.whens:
- x += "WHEN " + cond._compiler_dispatch(
- self, **kwargs
- ) + " THEN " + result._compiler_dispatch(
- self, **kwargs) + " "
+ x += (
+ "WHEN "
+ + cond._compiler_dispatch(self, **kwargs)
+ + " THEN "
+ + result._compiler_dispatch(self, **kwargs)
+ + " "
+ )
if clause.else_ is not None:
- x += "ELSE " + clause.else_._compiler_dispatch(
- self, **kwargs
- ) + " "
+ x += (
+ "ELSE " + clause.else_._compiler_dispatch(self, **kwargs) + " "
+ )
x += "END"
return x
@@ -849,79 +984,84 @@ class SQLCompiler(Compiled):
return type_coerce.typed_expression._compiler_dispatch(self, **kw)
def visit_cast(self, cast, **kwargs):
- return "CAST(%s AS %s)" % \
- (cast.clause._compiler_dispatch(self, **kwargs),
- cast.typeclause._compiler_dispatch(self, **kwargs))
+ return "CAST(%s AS %s)" % (
+ cast.clause._compiler_dispatch(self, **kwargs),
+ cast.typeclause._compiler_dispatch(self, **kwargs),
+ )
def _format_frame_clause(self, range_, **kw):
- return '%s AND %s' % (
+ return "%s AND %s" % (
"UNBOUNDED PRECEDING"
if range_[0] is elements.RANGE_UNBOUNDED
- else "CURRENT ROW" if range_[0] is elements.RANGE_CURRENT
- else "%s PRECEDING" % (
- self.process(elements.literal(abs(range_[0])), **kw), )
+ else "CURRENT ROW"
+ if range_[0] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[0])), **kw),)
if range_[0] < 0
- else "%s FOLLOWING" % (
- self.process(elements.literal(range_[0]), **kw), ),
-
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[0]), **kw),),
"UNBOUNDED FOLLOWING"
if range_[1] is elements.RANGE_UNBOUNDED
- else "CURRENT ROW" if range_[1] is elements.RANGE_CURRENT
- else "%s PRECEDING" % (
- self.process(elements.literal(abs(range_[1])), **kw), )
+ else "CURRENT ROW"
+ if range_[1] is elements.RANGE_CURRENT
+ else "%s PRECEDING"
+ % (self.process(elements.literal(abs(range_[1])), **kw),)
if range_[1] < 0
- else "%s FOLLOWING" % (
- self.process(elements.literal(range_[1]), **kw), ),
+ else "%s FOLLOWING"
+ % (self.process(elements.literal(range_[1]), **kw),),
)
def visit_over(self, over, **kwargs):
if over.range_:
range_ = "RANGE BETWEEN %s" % self._format_frame_clause(
- over.range_, **kwargs)
+ over.range_, **kwargs
+ )
elif over.rows:
range_ = "ROWS BETWEEN %s" % self._format_frame_clause(
- over.rows, **kwargs)
+ over.rows, **kwargs
+ )
else:
range_ = None
return "%s OVER (%s)" % (
over.element._compiler_dispatch(self, **kwargs),
- ' '.join([
- '%s BY %s' % (
- word, clause._compiler_dispatch(self, **kwargs)
- )
- for word, clause in (
- ('PARTITION', over.partition_by),
- ('ORDER', over.order_by)
- )
- if clause is not None and len(clause)
- ] + ([range_] if range_ else [])
- )
+ " ".join(
+ [
+ "%s BY %s"
+ % (word, clause._compiler_dispatch(self, **kwargs))
+ for word, clause in (
+ ("PARTITION", over.partition_by),
+ ("ORDER", over.order_by),
+ )
+ if clause is not None and len(clause)
+ ]
+ + ([range_] if range_ else [])
+ ),
)
def visit_withingroup(self, withingroup, **kwargs):
return "%s WITHIN GROUP (ORDER BY %s)" % (
withingroup.element._compiler_dispatch(self, **kwargs),
- withingroup.order_by._compiler_dispatch(self, **kwargs)
+ withingroup.order_by._compiler_dispatch(self, **kwargs),
)
def visit_funcfilter(self, funcfilter, **kwargs):
return "%s FILTER (WHERE %s)" % (
funcfilter.func._compiler_dispatch(self, **kwargs),
- funcfilter.criterion._compiler_dispatch(self, **kwargs)
+ funcfilter.criterion._compiler_dispatch(self, **kwargs),
)
def visit_extract(self, extract, **kwargs):
field = self.extract_map.get(extract.field, extract.field)
return "EXTRACT(%s FROM %s)" % (
- field, extract.expr._compiler_dispatch(self, **kwargs))
+ field,
+ extract.expr._compiler_dispatch(self, **kwargs),
+ )
def visit_function(self, func, add_to_result_map=None, **kwargs):
if add_to_result_map is not None:
- add_to_result_map(
- func.name, func.name, (), func.type
- )
+ add_to_result_map(func.name, func.name, (), func.type)
disp = getattr(self, "visit_%s_func" % func.name.lower(), None)
if disp:
@@ -933,51 +1073,63 @@ class SQLCompiler(Compiled):
name += "%(expr)s"
else:
name = func.name + "%(expr)s"
- return ".".join(list(func.packagenames) + [name]) % \
- {'expr': self.function_argspec(func, **kwargs)}
+ return ".".join(list(func.packagenames) + [name]) % {
+ "expr": self.function_argspec(func, **kwargs)
+ }
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
def visit_sequence(self, sequence, **kw):
raise NotImplementedError(
- "Dialect '%s' does not support sequence increments." %
- self.dialect.name
+ "Dialect '%s' does not support sequence increments."
+ % self.dialect.name
)
def function_argspec(self, func, **kwargs):
return func.clause_expr._compiler_dispatch(self, **kwargs)
- def visit_compound_select(self, cs, asfrom=False,
- parens=True, compound_index=0, **kwargs):
+ def visit_compound_select(
+ self, cs, asfrom=False, parens=True, compound_index=0, **kwargs
+ ):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- need_result_map = toplevel or \
- (compound_index == 0
- and entry.get('need_result_map_for_compound', False))
+ need_result_map = toplevel or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
self.stack.append(
{
- 'correlate_froms': entry['correlate_froms'],
- 'asfrom_froms': entry['asfrom_froms'],
- 'selectable': cs,
- 'need_result_map_for_compound': need_result_map
- })
+ "correlate_froms": entry["correlate_froms"],
+ "asfrom_froms": entry["asfrom_froms"],
+ "selectable": cs,
+ "need_result_map_for_compound": need_result_map,
+ }
+ )
keyword = self.compound_keywords.get(cs.keyword)
text = (" " + keyword + " ").join(
- (c._compiler_dispatch(self,
- asfrom=asfrom, parens=False,
- compound_index=i, **kwargs)
- for i, c in enumerate(cs.selects))
+ (
+ c._compiler_dispatch(
+ self,
+ asfrom=asfrom,
+ parens=False,
+ compound_index=i,
+ **kwargs
+ )
+ for i, c in enumerate(cs.selects)
+ )
)
text += self.group_by_clause(cs, **dict(asfrom=asfrom, **kwargs))
text += self.order_by_clause(cs, **kwargs)
- text += (cs._limit_clause is not None
- or cs._offset_clause is not None) and \
- self.limit_clause(cs, **kwargs) or ""
+ text += (
+ (cs._limit_clause is not None or cs._offset_clause is not None)
+ and self.limit_clause(cs, **kwargs)
+ or ""
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -990,8 +1142,10 @@ class SQLCompiler(Compiled):
def _get_operator_dispatch(self, operator_, qualifier1, qualifier2):
attrname = "visit_%s_%s%s" % (
- operator_.__name__, qualifier1,
- "_" + qualifier2 if qualifier2 else "")
+ operator_.__name__,
+ qualifier1,
+ "_" + qualifier2 if qualifier2 else "",
+ )
return getattr(self, attrname, None)
def visit_unary(self, unary, **kw):
@@ -999,51 +1153,63 @@ class SQLCompiler(Compiled):
if unary.modifier:
raise exc.CompileError(
"Unary expression does not support operator "
- "and modifier simultaneously")
+ "and modifier simultaneously"
+ )
disp = self._get_operator_dispatch(
- unary.operator, "unary", "operator")
+ unary.operator, "unary", "operator"
+ )
if disp:
return disp(unary, unary.operator, **kw)
else:
return self._generate_generic_unary_operator(
- unary, OPERATORS[unary.operator], **kw)
+ unary, OPERATORS[unary.operator], **kw
+ )
elif unary.modifier:
disp = self._get_operator_dispatch(
- unary.modifier, "unary", "modifier")
+ unary.modifier, "unary", "modifier"
+ )
if disp:
return disp(unary, unary.modifier, **kw)
else:
return self._generate_generic_unary_modifier(
- unary, OPERATORS[unary.modifier], **kw)
+ unary, OPERATORS[unary.modifier], **kw
+ )
else:
raise exc.CompileError(
- "Unary expression has no operator or modifier")
+ "Unary expression has no operator or modifier"
+ )
def visit_istrue_unary_operator(self, element, operator, **kw):
- if element._is_implicitly_boolean or \
- self.dialect.supports_native_boolean:
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
return self.process(element.element, **kw)
else:
return "%s = 1" % self.process(element.element, **kw)
def visit_isfalse_unary_operator(self, element, operator, **kw):
- if element._is_implicitly_boolean or \
- self.dialect.supports_native_boolean:
+ if (
+ element._is_implicitly_boolean
+ or self.dialect.supports_native_boolean
+ ):
return "NOT %s" % self.process(element.element, **kw)
else:
return "%s = 0" % self.process(element.element, **kw)
def visit_notmatch_op_binary(self, binary, operator, **kw):
return "NOT %s" % self.visit_binary(
- binary, override_operator=operators.match_op)
+ binary, override_operator=operators.match_op
+ )
def _emit_empty_in_warning(self):
util.warn(
- 'The IN-predicate was invoked with an '
- 'empty sequence. This results in a '
- 'contradiction, which nonetheless can be '
- 'expensive to evaluate. Consider alternative '
- 'strategies for improved performance.')
+ "The IN-predicate was invoked with an "
+ "empty sequence. This results in a "
+ "contradiction, which nonetheless can be "
+ "expensive to evaluate. Consider alternative "
+ "strategies for improved performance."
+ )
def visit_empty_in_op_binary(self, binary, operator, **kw):
if self.dialect._use_static_in:
@@ -1063,18 +1229,21 @@ class SQLCompiler(Compiled):
def visit_empty_set_expr(self, element_types):
raise NotImplementedError(
- "Dialect '%s' does not support empty set expression." %
- self.dialect.name
+ "Dialect '%s' does not support empty set expression."
+ % self.dialect.name
)
- def visit_binary(self, binary, override_operator=None,
- eager_grouping=False, **kw):
+ def visit_binary(
+ self, binary, override_operator=None, eager_grouping=False, **kw
+ ):
# don't allow "? = ?" to render
- if self.ansi_bind_rules and \
- isinstance(binary.left, elements.BindParameter) and \
- isinstance(binary.right, elements.BindParameter):
- kw['literal_binds'] = True
+ if (
+ self.ansi_bind_rules
+ and isinstance(binary.left, elements.BindParameter)
+ and isinstance(binary.right, elements.BindParameter)
+ ):
+ kw["literal_binds"] = True
operator_ = override_operator or binary.operator
disp = self._get_operator_dispatch(operator_, "binary", None)
@@ -1093,36 +1262,50 @@ class SQLCompiler(Compiled):
def visit_mod_binary(self, binary, operator, **kw):
if self.preparer._double_percents:
- return self.process(binary.left, **kw) + " %% " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " %% "
+ + self.process(binary.right, **kw)
+ )
else:
- return self.process(binary.left, **kw) + " % " + \
- self.process(binary.right, **kw)
+ return (
+ self.process(binary.left, **kw)
+ + " % "
+ + self.process(binary.right, **kw)
+ )
def visit_custom_op_binary(self, element, operator, **kw):
- kw['eager_grouping'] = operator.eager_grouping
+ kw["eager_grouping"] = operator.eager_grouping
return self._generate_generic_binary(
- element, " " + operator.opstring + " ", **kw)
+ element, " " + operator.opstring + " ", **kw
+ )
def visit_custom_op_unary_operator(self, element, operator, **kw):
return self._generate_generic_unary_operator(
- element, operator.opstring + " ", **kw)
+ element, operator.opstring + " ", **kw
+ )
def visit_custom_op_unary_modifier(self, element, operator, **kw):
return self._generate_generic_unary_modifier(
- element, " " + operator.opstring, **kw)
+ element, " " + operator.opstring, **kw
+ )
def _generate_generic_binary(
- self, binary, opstring, eager_grouping=False, **kw):
+ self, binary, opstring, eager_grouping=False, **kw
+ ):
- _in_binary = kw.get('_in_binary', False)
+ _in_binary = kw.get("_in_binary", False)
- kw['_in_binary'] = True
- text = binary.left._compiler_dispatch(
- self, eager_grouping=eager_grouping, **kw) + \
- opstring + \
- binary.right._compiler_dispatch(
- self, eager_grouping=eager_grouping, **kw)
+ kw["_in_binary"] = True
+ text = (
+ binary.left._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ + opstring
+ + binary.right._compiler_dispatch(
+ self, eager_grouping=eager_grouping, **kw
+ )
+ )
if _in_binary and eager_grouping:
text = "(%s)" % text
@@ -1153,17 +1336,13 @@ class SQLCompiler(Compiled):
def visit_startswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
- binary.right = percent.__radd__(
- binary.right
- )
+ binary.right = percent.__radd__(binary.right)
return self.visit_like_op_binary(binary, operator, **kw)
def visit_notstartswith_op_binary(self, binary, operator, **kw):
binary = binary._clone()
percent = self._like_percent_literal
- binary.right = percent.__radd__(
- binary.right
- )
+ binary.right = percent.__radd__(binary.right)
return self.visit_notlike_op_binary(binary, operator, **kw)
def visit_endswith_op_binary(self, binary, operator, **kw):
@@ -1182,98 +1361,105 @@ class SQLCompiler(Compiled):
escape = binary.modifiers.get("escape", None)
# TODO: use ternary here, not "and"/ "or"
- return '%s LIKE %s' % (
+ return "%s LIKE %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_notlike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return '%s NOT LIKE %s' % (
+ return "%s NOT LIKE %s" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_ilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) LIKE lower(%s)' % (
+ return "lower(%s) LIKE lower(%s)" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_notilike_op_binary(self, binary, operator, **kw):
escape = binary.modifiers.get("escape", None)
- return 'lower(%s) NOT LIKE lower(%s)' % (
+ return "lower(%s) NOT LIKE lower(%s)" % (
binary.left._compiler_dispatch(self, **kw),
- binary.right._compiler_dispatch(self, **kw)) \
- + (
- ' ESCAPE ' +
- self.render_literal_value(escape, sqltypes.STRINGTYPE)
- if escape else ''
- )
+ binary.right._compiler_dispatch(self, **kw),
+ ) + (
+ " ESCAPE " + self.render_literal_value(escape, sqltypes.STRINGTYPE)
+ if escape
+ else ""
+ )
def visit_between_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " BETWEEN SYMMETRIC "
- if symmetric else " BETWEEN ", **kw)
+ binary, " BETWEEN SYMMETRIC " if symmetric else " BETWEEN ", **kw
+ )
def visit_notbetween_op_binary(self, binary, operator, **kw):
symmetric = binary.modifiers.get("symmetric", False)
return self._generate_generic_binary(
- binary, " NOT BETWEEN SYMMETRIC "
- if symmetric else " NOT BETWEEN ", **kw)
+ binary,
+ " NOT BETWEEN SYMMETRIC " if symmetric else " NOT BETWEEN ",
+ **kw
+ )
- def visit_bindparam(self, bindparam, within_columns_clause=False,
- literal_binds=False,
- skip_bind_expression=False,
- **kwargs):
+ def visit_bindparam(
+ self,
+ bindparam,
+ within_columns_clause=False,
+ literal_binds=False,
+ skip_bind_expression=False,
+ **kwargs
+ ):
if not skip_bind_expression:
impl = bindparam.type.dialect_impl(self.dialect)
if impl._has_bind_expression:
bind_expression = impl.bind_expression(bindparam)
return self.process(
- bind_expression, skip_bind_expression=True,
+ bind_expression,
+ skip_bind_expression=True,
within_columns_clause=within_columns_clause,
literal_binds=literal_binds,
**kwargs
)
- if literal_binds or \
- (within_columns_clause and
- self.ansi_bind_rules):
+ if literal_binds or (within_columns_clause and self.ansi_bind_rules):
if bindparam.value is None and bindparam.callable is None:
- raise exc.CompileError("Bind parameter '%s' without a "
- "renderable value not allowed here."
- % bindparam.key)
+ raise exc.CompileError(
+ "Bind parameter '%s' without a "
+ "renderable value not allowed here." % bindparam.key
+ )
return self.render_literal_bindparam(
- bindparam, within_columns_clause=True, **kwargs)
+ bindparam, within_columns_clause=True, **kwargs
+ )
name = self._truncate_bindparam(bindparam)
if name in self.binds:
existing = self.binds[name]
if existing is not bindparam:
- if (existing.unique or bindparam.unique) and \
- not existing.proxy_set.intersection(
- bindparam.proxy_set):
+ if (
+ existing.unique or bindparam.unique
+ ) and not existing.proxy_set.intersection(bindparam.proxy_set):
raise exc.CompileError(
"Bind parameter '%s' conflicts with "
- "unique bind parameter of the same name" %
- bindparam.key
+ "unique bind parameter of the same name"
+ % bindparam.key
)
elif existing._is_crud or bindparam._is_crud:
raise exc.CompileError(
@@ -1282,14 +1468,15 @@ class SQLCompiler(Compiled):
"clause of this "
"insert/update statement. Please use a "
"name other than column name when using bindparam() "
- "with insert() or update() (for example, 'b_%s')." %
- (bindparam.key, bindparam.key)
+ "with insert() or update() (for example, 'b_%s')."
+ % (bindparam.key, bindparam.key)
)
self.binds[bindparam.key] = self.binds[name] = bindparam
return self.bindparam_string(
- name, expanding=bindparam.expanding, **kwargs)
+ name, expanding=bindparam.expanding, **kwargs
+ )
def render_literal_bindparam(self, bindparam, **kw):
value = bindparam.effective_value
@@ -1311,7 +1498,8 @@ class SQLCompiler(Compiled):
return processor(value)
else:
raise NotImplementedError(
- "Don't know how to literal-quote value %r" % value)
+ "Don't know how to literal-quote value %r" % value
+ )
def _truncate_bindparam(self, bindparam):
if bindparam in self.bind_names:
@@ -1334,8 +1522,11 @@ class SQLCompiler(Compiled):
if len(anonname) > self.label_length - 6:
counter = self.truncated_names.get(ident_class, 1)
- truncname = anonname[0:max(self.label_length - 6, 0)] + \
- "_" + hex(counter)[2:]
+ truncname = (
+ anonname[0 : max(self.label_length - 6, 0)]
+ + "_"
+ + hex(counter)[2:]
+ )
self.truncated_names[ident_class] = counter + 1
else:
truncname = anonname
@@ -1346,13 +1537,14 @@ class SQLCompiler(Compiled):
return name % self.anon_map
def _process_anon(self, key):
- (ident, derived) = key.split(' ', 1)
+ (ident, derived) = key.split(" ", 1)
anonymous_counter = self.anon_map.get(derived, 1)
self.anon_map[derived] = anonymous_counter + 1
return derived + "_" + str(anonymous_counter)
def bindparam_string(
- self, name, positional_names=None, expanding=False, **kw):
+ self, name, positional_names=None, expanding=False, **kw
+ ):
if self.positional:
if positional_names is not None:
positional_names.append(name)
@@ -1362,14 +1554,20 @@ class SQLCompiler(Compiled):
self.contains_expanding_parameters = True
return "([EXPANDING_%s])" % name
else:
- return self.bindtemplate % {'name': name}
-
- def visit_cte(self, cte, asfrom=False, ashint=False,
- fromhints=None, visiting_cte=None,
- **kwargs):
+ return self.bindtemplate % {"name": name}
+
+ def visit_cte(
+ self,
+ cte,
+ asfrom=False,
+ ashint=False,
+ fromhints=None,
+ visiting_cte=None,
+ **kwargs
+ ):
self._init_cte_state()
- kwargs['visiting_cte'] = cte
+ kwargs["visiting_cte"] = cte
if isinstance(cte.name, elements._truncated_label):
cte_name = self._truncated_identifier("alias", cte.name)
else:
@@ -1394,8 +1592,8 @@ class SQLCompiler(Compiled):
else:
raise exc.CompileError(
"Multiple, unrelated CTEs found with "
- "the same name: %r" %
- cte_name)
+ "the same name: %r" % cte_name
+ )
if asfrom or is_new_cte:
if cte._cte_alias is not None:
@@ -1403,7 +1601,8 @@ class SQLCompiler(Compiled):
cte_pre_alias_name = cte._cte_alias.name
if isinstance(cte_pre_alias_name, elements._truncated_label):
cte_pre_alias_name = self._truncated_identifier(
- "alias", cte_pre_alias_name)
+ "alias", cte_pre_alias_name
+ )
else:
pre_alias_cte = cte
cte_pre_alias_name = None
@@ -1412,11 +1611,17 @@ class SQLCompiler(Compiled):
self.ctes_by_name[cte_name] = cte
# look for embedded DML ctes and propagate autocommit
- if 'autocommit' in cte.element._execution_options and \
- 'autocommit' not in self.execution_options:
+ if (
+ "autocommit" in cte.element._execution_options
+ and "autocommit" not in self.execution_options
+ ):
self.execution_options = self.execution_options.union(
- {"autocommit":
- cte.element._execution_options['autocommit']})
+ {
+ "autocommit": cte.element._execution_options[
+ "autocommit"
+ ]
+ }
+ )
if pre_alias_cte not in self.ctes:
self.visit_cte(pre_alias_cte, **kwargs)
@@ -1432,25 +1637,30 @@ class SQLCompiler(Compiled):
col_source = cte.original.selects[0]
else:
assert False
- recur_cols = [c for c in
- util.unique_list(col_source.inner_columns)
- if c is not None]
-
- text += "(%s)" % (", ".join(
- self.preparer.format_column(ident)
- for ident in recur_cols))
+ recur_cols = [
+ c
+ for c in util.unique_list(col_source.inner_columns)
+ if c is not None
+ ]
+
+ text += "(%s)" % (
+ ", ".join(
+ self.preparer.format_column(ident)
+ for ident in recur_cols
+ )
+ )
if self.positional:
- kwargs['positional_names'] = self.cte_positional[cte] = []
+ kwargs["positional_names"] = self.cte_positional[cte] = []
- text += " AS \n" + \
- cte.original._compiler_dispatch(
- self, asfrom=True, **kwargs
- )
+ text += " AS \n" + cte.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ )
if cte._suffixes:
text += " " + self._generate_prefixes(
- cte, cte._suffixes, **kwargs)
+ cte, cte._suffixes, **kwargs
+ )
self.ctes[cte] = text
@@ -1467,9 +1677,15 @@ class SQLCompiler(Compiled):
else:
return self.preparer.format_alias(cte, cte_name)
- def visit_alias(self, alias, asfrom=False, ashint=False,
- iscrud=False,
- fromhints=None, **kwargs):
+ def visit_alias(
+ self,
+ alias,
+ asfrom=False,
+ ashint=False,
+ iscrud=False,
+ fromhints=None,
+ **kwargs
+ ):
if asfrom or ashint:
if isinstance(alias.name, elements._truncated_label):
alias_name = self._truncated_identifier("alias", alias.name)
@@ -1479,31 +1695,35 @@ class SQLCompiler(Compiled):
if ashint:
return self.preparer.format_alias(alias, alias_name)
elif asfrom:
- ret = alias.original._compiler_dispatch(self,
- asfrom=True, **kwargs) + \
- self.get_render_as_alias_suffix(
- self.preparer.format_alias(alias, alias_name))
+ ret = alias.original._compiler_dispatch(
+ self, asfrom=True, **kwargs
+ ) + self.get_render_as_alias_suffix(
+ self.preparer.format_alias(alias, alias_name)
+ )
if fromhints and alias in fromhints:
- ret = self.format_from_hint_text(ret, alias,
- fromhints[alias], iscrud)
+ ret = self.format_from_hint_text(
+ ret, alias, fromhints[alias], iscrud
+ )
return ret
else:
return alias.original._compiler_dispatch(self, **kwargs)
def visit_lateral(self, lateral, **kw):
- kw['lateral'] = True
+ kw["lateral"] = True
return "LATERAL %s" % self.visit_alias(lateral, **kw)
def visit_tablesample(self, tablesample, asfrom=False, **kw):
text = "%s TABLESAMPLE %s" % (
self.visit_alias(tablesample, asfrom=True, **kw),
- tablesample._get_method()._compiler_dispatch(self, **kw))
+ tablesample._get_method()._compiler_dispatch(self, **kw),
+ )
if tablesample.seed is not None:
text += " REPEATABLE (%s)" % (
- tablesample.seed._compiler_dispatch(self, **kw))
+ tablesample.seed._compiler_dispatch(self, **kw)
+ )
return text
@@ -1513,22 +1733,27 @@ class SQLCompiler(Compiled):
def _add_to_result_map(self, keyname, name, objects, type_):
self._result_columns.append((keyname, name, objects, type_))
- def _label_select_column(self, select, column,
- populate_result_map,
- asfrom, column_clause_args,
- name=None,
- within_columns_clause=True):
+ def _label_select_column(
+ self,
+ select,
+ column,
+ populate_result_map,
+ asfrom,
+ column_clause_args,
+ name=None,
+ within_columns_clause=True,
+ ):
"""produce labeled columns present in a select()."""
impl = column.type.dialect_impl(self.dialect)
- if impl._has_column_expression and \
- populate_result_map:
+ if impl._has_column_expression and populate_result_map:
col_expr = impl.column_expression(column)
def add_to_result_map(keyname, name, objects, type_):
self._add_to_result_map(
- keyname, name,
- (column,) + objects, type_)
+ keyname, name, (column,) + objects, type_
+ )
+
else:
col_expr = column
if populate_result_map:
@@ -1541,58 +1766,56 @@ class SQLCompiler(Compiled):
elif isinstance(column, elements.Label):
if col_expr is not column:
result_expr = _CompileLabel(
- col_expr,
- column.name,
- alt_names=(column.element,)
+ col_expr, column.name, alt_names=(column.element,)
)
else:
result_expr = col_expr
elif select is not None and name:
result_expr = _CompileLabel(
+ col_expr, name, alt_names=(column._key_label,)
+ )
+
+ elif (
+ asfrom
+ and isinstance(column, elements.ColumnClause)
+ and not column.is_literal
+ and column.table is not None
+ and not isinstance(column.table, selectable.Select)
+ ):
+ result_expr = _CompileLabel(
col_expr,
- name,
- alt_names=(column._key_label,)
- )
-
- elif \
- asfrom and \
- isinstance(column, elements.ColumnClause) and \
- not column.is_literal and \
- column.table is not None and \
- not isinstance(column.table, selectable.Select):
- result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ elements._as_truncated(column.name),
+ alt_names=(column.key,),
+ )
elif (
- not isinstance(column, elements.TextClause) and
- (
- not isinstance(column, elements.UnaryExpression) or
- column.wraps_column_expression
- ) and
- (
- not hasattr(column, 'name') or
- isinstance(column, functions.Function)
+ not isinstance(column, elements.TextClause)
+ and (
+ not isinstance(column, elements.UnaryExpression)
+ or column.wraps_column_expression
+ )
+ and (
+ not hasattr(column, "name")
+ or isinstance(column, functions.Function)
)
):
result_expr = _CompileLabel(col_expr, column.anon_label)
elif col_expr is not column:
# TODO: are we sure "column" has a .name and .key here ?
# assert isinstance(column, elements.ColumnClause)
- result_expr = _CompileLabel(col_expr,
- elements._as_truncated(column.name),
- alt_names=(column.key,))
+ result_expr = _CompileLabel(
+ col_expr,
+ elements._as_truncated(column.name),
+ alt_names=(column.key,),
+ )
else:
result_expr = col_expr
column_clause_args.update(
within_columns_clause=within_columns_clause,
- add_to_result_map=add_to_result_map
- )
- return result_expr._compiler_dispatch(
- self,
- **column_clause_args
+ add_to_result_map=add_to_result_map,
)
+ return result_expr._compiler_dispatch(self, **column_clause_args)
def format_from_hint_text(self, sqltext, table, hint, iscrud):
hinttext = self.get_from_hint_text(table, hint)
@@ -1631,8 +1854,11 @@ class SQLCompiler(Compiled):
newelem = cloned[element] = element._clone()
- if newelem.is_selectable and newelem._is_join and \
- isinstance(newelem.right, selectable.FromGrouping):
+ if (
+ newelem.is_selectable
+ and newelem._is_join
+ and isinstance(newelem.right, selectable.FromGrouping)
+ ):
newelem._reset_exported()
newelem.left = visit(newelem.left, **kw)
@@ -1640,8 +1866,8 @@ class SQLCompiler(Compiled):
right = visit(newelem.right, **kw)
selectable_ = selectable.Select(
- [right.element],
- use_labels=True).alias()
+ [right.element], use_labels=True
+ ).alias()
for c in selectable_.c:
c._key_label = c.key
@@ -1680,17 +1906,18 @@ class SQLCompiler(Compiled):
elif newelem._is_from_container:
# if we hit an Alias, CompoundSelect or ScalarSelect, put a
# marker in the stack.
- kw['transform_clue'] = 'select_container'
+ kw["transform_clue"] = "select_container"
newelem._copy_internals(clone=visit, **kw)
elif newelem.is_selectable and newelem._is_select:
- barrier_select = kw.get('transform_clue', None) == \
- 'select_container'
+ barrier_select = (
+ kw.get("transform_clue", None) == "select_container"
+ )
# if we're still descended from an
# Alias/CompoundSelect/ScalarSelect, we're
# in a FROM clause, so start with a new translate collection
if barrier_select:
column_translate.append({})
- kw['transform_clue'] = 'inside_select'
+ kw["transform_clue"] = "inside_select"
newelem._copy_internals(clone=visit, **kw)
if barrier_select:
del column_translate[-1]
@@ -1702,24 +1929,22 @@ class SQLCompiler(Compiled):
return visit(select)
def _transform_result_map_for_nested_joins(
- self, select, transformed_select):
- inner_col = dict((c._key_label, c) for
- c in transformed_select.inner_columns)
-
- d = dict(
- (inner_col[c._key_label], c)
- for c in select.inner_columns
+ self, select, transformed_select
+ ):
+ inner_col = dict(
+ (c._key_label, c) for c in transformed_select.inner_columns
)
+ d = dict((inner_col[c._key_label], c) for c in select.inner_columns)
+
self._result_columns = [
(key, name, tuple([d.get(col, col) for col in objs]), typ)
for key, name, objs, typ in self._result_columns
]
- _default_stack_entry = util.immutabledict([
- ('correlate_froms', frozenset()),
- ('asfrom_froms', frozenset())
- ])
+ _default_stack_entry = util.immutabledict(
+ [("correlate_froms", frozenset()), ("asfrom_froms", frozenset())]
+ )
def _display_froms_for_select(self, select, asfrom, lateral=False):
# utility method to help external dialects
@@ -1729,72 +1954,88 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
else:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ implicit_correlate_froms=asfrom_froms,
+ )
return froms
- def visit_select(self, select, asfrom=False, parens=True,
- fromhints=None,
- compound_index=0,
- nested_join_translation=False,
- select_wraps_for=None,
- lateral=False,
- **kwargs):
-
- needs_nested_translation = \
- select.use_labels and \
- not nested_join_translation and \
- not self.stack and \
- not self.dialect.supports_right_nested_joins
+ def visit_select(
+ self,
+ select,
+ asfrom=False,
+ parens=True,
+ fromhints=None,
+ compound_index=0,
+ nested_join_translation=False,
+ select_wraps_for=None,
+ lateral=False,
+ **kwargs
+ ):
+
+ needs_nested_translation = (
+ select.use_labels
+ and not nested_join_translation
+ and not self.stack
+ and not self.dialect.supports_right_nested_joins
+ )
if needs_nested_translation:
transformed_select = self._transform_select_for_nested_joins(
- select)
+ select
+ )
text = self.visit_select(
- transformed_select, asfrom=asfrom, parens=parens,
+ transformed_select,
+ asfrom=asfrom,
+ parens=parens,
fromhints=fromhints,
compound_index=compound_index,
- nested_join_translation=True, **kwargs
+ nested_join_translation=True,
+ **kwargs
)
toplevel = not self.stack
entry = self._default_stack_entry if toplevel else self.stack[-1]
- populate_result_map = toplevel or \
- (
- compound_index == 0 and entry.get(
- 'need_result_map_for_compound', False)
- ) or entry.get('need_result_map_for_nested', False)
+ populate_result_map = (
+ toplevel
+ or (
+ compound_index == 0
+ and entry.get("need_result_map_for_compound", False)
+ )
+ or entry.get("need_result_map_for_nested", False)
+ )
# this was first proposed as part of #3372; however, it is not
# reached in current tests and could possibly be an assertion
# instead.
- if not populate_result_map and 'add_to_result_map' in kwargs:
- del kwargs['add_to_result_map']
+ if not populate_result_map and "add_to_result_map" in kwargs:
+ del kwargs["add_to_result_map"]
if needs_nested_translation:
if populate_result_map:
self._transform_result_map_for_nested_joins(
- select, transformed_select)
+ select, transformed_select
+ )
return text
froms = self._setup_select_stack(select, entry, asfrom, lateral)
column_clause_args = kwargs.copy()
- column_clause_args.update({
- 'within_label_clause': False,
- 'within_columns_clause': False
- })
+ column_clause_args.update(
+ {"within_label_clause": False, "within_columns_clause": False}
+ )
text = "SELECT " # we're off to a good start !
@@ -1806,19 +2047,21 @@ class SQLCompiler(Compiled):
byfrom = None
if select._prefixes:
- text += self._generate_prefixes(
- select, select._prefixes, **kwargs)
+ text += self._generate_prefixes(select, select._prefixes, **kwargs)
text += self.get_select_precolumns(select, **kwargs)
# the actual list of columns to print in the SELECT column list.
inner_columns = [
- c for c in [
+ c
+ for c in [
self._label_select_column(
select,
column,
- populate_result_map, asfrom,
+ populate_result_map,
+ asfrom,
column_clause_args,
- name=name)
+ name=name,
+ )
for name, column in select._columns_plus_names
]
if c is not None
@@ -1831,8 +2074,11 @@ class SQLCompiler(Compiled):
translate = dict(
zip(
[name for (key, name) in select._columns_plus_names],
- [name for (key, name) in
- select_wraps_for._columns_plus_names])
+ [
+ name
+ for (key, name) in select_wraps_for._columns_plus_names
+ ],
+ )
)
self._result_columns = [
@@ -1841,13 +2087,14 @@ class SQLCompiler(Compiled):
]
text = self._compose_select_body(
- text, select, inner_columns, froms, byfrom, kwargs)
+ text, select, inner_columns, froms, byfrom, kwargs
+ )
if select._statement_hints:
per_dialect = [
- ht for (dialect_name, ht)
- in select._statement_hints
- if dialect_name in ('*', self.dialect.name)
+ ht
+ for (dialect_name, ht) in select._statement_hints
+ if dialect_name in ("*", self.dialect.name)
]
if per_dialect:
text += " " + self.get_statement_hint_text(per_dialect)
@@ -1857,7 +2104,8 @@ class SQLCompiler(Compiled):
if select._suffixes:
text += " " + self._generate_prefixes(
- select, select._suffixes, **kwargs)
+ select, select._suffixes, **kwargs
+ )
self.stack.pop(-1)
@@ -1867,60 +2115,73 @@ class SQLCompiler(Compiled):
return text
def _setup_select_hints(self, select):
- byfrom = dict([
- (from_, hinttext % {
- 'name': from_._compiler_dispatch(
- self, ashint=True)
- })
- for (from_, dialect), hinttext in
- select._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
+ byfrom = dict(
+ [
+ (
+ from_,
+ hinttext
+ % {"name": from_._compiler_dispatch(self, ashint=True)},
+ )
+ for (from_, dialect), hinttext in select._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
hint_text = self.get_select_hint_text(byfrom)
return hint_text, byfrom
def _setup_select_stack(self, select, entry, asfrom, lateral):
- correlate_froms = entry['correlate_froms']
- asfrom_froms = entry['asfrom_froms']
+ correlate_froms = entry["correlate_froms"]
+ asfrom_froms = entry["asfrom_froms"]
if asfrom and not lateral:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms.difference(
- asfrom_froms),
- implicit_correlate_froms=())
+ asfrom_froms
+ ),
+ implicit_correlate_froms=(),
+ )
else:
froms = select._get_display_froms(
explicit_correlate_froms=correlate_froms,
- implicit_correlate_froms=asfrom_froms)
+ implicit_correlate_froms=asfrom_froms,
+ )
new_correlate_froms = set(selectable._from_objects(*froms))
all_correlate_froms = new_correlate_froms.union(correlate_froms)
new_entry = {
- 'asfrom_froms': new_correlate_froms,
- 'correlate_froms': all_correlate_froms,
- 'selectable': select,
+ "asfrom_froms": new_correlate_froms,
+ "correlate_froms": all_correlate_froms,
+ "selectable": select,
}
self.stack.append(new_entry)
return froms
def _compose_select_body(
- self, text, select, inner_columns, froms, byfrom, kwargs):
- text += ', '.join(inner_columns)
+ self, text, select, inner_columns, froms, byfrom, kwargs
+ ):
+ text += ", ".join(inner_columns)
if froms:
text += " \nFROM "
if select._hints:
- text += ', '.join(
- [f._compiler_dispatch(self, asfrom=True,
- fromhints=byfrom, **kwargs)
- for f in froms])
+ text += ", ".join(
+ [
+ f._compiler_dispatch(
+ self, asfrom=True, fromhints=byfrom, **kwargs
+ )
+ for f in froms
+ ]
+ )
else:
- text += ', '.join(
- [f._compiler_dispatch(self, asfrom=True, **kwargs)
- for f in froms])
+ text += ", ".join(
+ [
+ f._compiler_dispatch(self, asfrom=True, **kwargs)
+ for f in froms
+ ]
+ )
else:
text += self.default_from()
@@ -1940,8 +2201,10 @@ class SQLCompiler(Compiled):
if select._order_by_clause.clauses:
text += self.order_by_clause(select, **kwargs)
- if (select._limit_clause is not None or
- select._offset_clause is not None):
+ if (
+ select._limit_clause is not None
+ or select._offset_clause is not None
+ ):
text += self.limit_clause(select, **kwargs)
if select._for_update_arg is not None:
@@ -1953,8 +2216,7 @@ class SQLCompiler(Compiled):
clause = " ".join(
prefix._compiler_dispatch(self, **kw)
for prefix, dialect_name in prefixes
- if dialect_name is None or
- dialect_name == self.dialect.name
+ if dialect_name is None or dialect_name == self.dialect.name
)
if clause:
clause += " "
@@ -1962,14 +2224,12 @@ class SQLCompiler(Compiled):
def _render_cte_clause(self):
if self.positional:
- self.positiontup = sum([
- self.cte_positional[cte]
- for cte in self.ctes], []) + \
- self.positiontup
+ self.positiontup = (
+ sum([self.cte_positional[cte] for cte in self.ctes], [])
+ + self.positiontup
+ )
cte_text = self.get_cte_preamble(self.ctes_recursive) + " "
- cte_text += ", \n".join(
- [txt for txt in self.ctes.values()]
- )
+ cte_text += ", \n".join([txt for txt in self.ctes.values()])
cte_text += "\n "
return cte_text
@@ -2010,7 +2270,8 @@ class SQLCompiler(Compiled):
def returning_clause(self, stmt, returning_cols):
raise exc.CompileError(
"RETURNING is not supported by this "
- "dialect's statement compiler.")
+ "dialect's statement compiler."
+ )
def limit_clause(self, select, **kw):
text = ""
@@ -2022,19 +2283,31 @@ class SQLCompiler(Compiled):
text += " OFFSET " + self.process(select._offset_clause, **kw)
return text
- def visit_table(self, table, asfrom=False, iscrud=False, ashint=False,
- fromhints=None, use_schema=True, **kwargs):
+ def visit_table(
+ self,
+ table,
+ asfrom=False,
+ iscrud=False,
+ ashint=False,
+ fromhints=None,
+ use_schema=True,
+ **kwargs
+ ):
if asfrom or ashint:
effective_schema = self.preparer.schema_for_object(table)
if use_schema and effective_schema:
- ret = self.preparer.quote_schema(effective_schema) + \
- "." + self.preparer.quote(table.name)
+ ret = (
+ self.preparer.quote_schema(effective_schema)
+ + "."
+ + self.preparer.quote(table.name)
+ )
else:
ret = self.preparer.quote(table.name)
if fromhints and table in fromhints:
- ret = self.format_from_hint_text(ret, table,
- fromhints[table], iscrud)
+ ret = self.format_from_hint_text(
+ ret, table, fromhints[table], iscrud
+ )
return ret
else:
return ""
@@ -2047,26 +2320,24 @@ class SQLCompiler(Compiled):
else:
join_type = " JOIN "
return (
- join.left._compiler_dispatch(self, asfrom=True, **kwargs) +
- join_type +
- join.right._compiler_dispatch(self, asfrom=True, **kwargs) +
- " ON " +
- join.onclause._compiler_dispatch(self, **kwargs)
+ join.left._compiler_dispatch(self, asfrom=True, **kwargs)
+ + join_type
+ + join.right._compiler_dispatch(self, asfrom=True, **kwargs)
+ + " ON "
+ + join.onclause._compiler_dispatch(self, **kwargs)
)
def _setup_crud_hints(self, stmt, table_text):
- dialect_hints = dict([
- (table, hint_text)
- for (table, dialect), hint_text in
- stmt._hints.items()
- if dialect in ('*', self.dialect.name)
- ])
+ dialect_hints = dict(
+ [
+ (table, hint_text)
+ for (table, dialect), hint_text in stmt._hints.items()
+ if dialect in ("*", self.dialect.name)
+ ]
+ )
if stmt.table in dialect_hints:
table_text = self.format_from_hint_text(
- table_text,
- stmt.table,
- dialect_hints[stmt.table],
- True
+ table_text, stmt.table, dialect_hints[stmt.table], True
)
return dialect_hints, table_text
@@ -2074,28 +2345,35 @@ class SQLCompiler(Compiled):
toplevel = not self.stack
self.stack.append(
- {'correlate_froms': set(),
- "asfrom_froms": set(),
- "selectable": insert_stmt})
+ {
+ "correlate_froms": set(),
+ "asfrom_froms": set(),
+ "selectable": insert_stmt,
+ }
+ )
crud_params = crud._setup_crud_params(
- self, insert_stmt, crud.ISINSERT, **kw)
+ self, insert_stmt, crud.ISINSERT, **kw
+ )
- if not crud_params and \
- not self.dialect.supports_default_values and \
- not self.dialect.supports_empty_insert:
- raise exc.CompileError("The '%s' dialect with current database "
- "version settings does not support empty "
- "inserts." %
- self.dialect.name)
+ if (
+ not crud_params
+ and not self.dialect.supports_default_values
+ and not self.dialect.supports_empty_insert
+ ):
+ raise exc.CompileError(
+ "The '%s' dialect with current database "
+ "version settings does not support empty "
+ "inserts." % self.dialect.name
+ )
if insert_stmt._has_multi_parameters:
if not self.dialect.supports_multivalues_insert:
raise exc.CompileError(
"The '%s' dialect with current database "
"version settings does not support "
- "in-place multirow inserts." %
- self.dialect.name)
+ "in-place multirow inserts." % self.dialect.name
+ )
crud_params_single = crud_params[0]
else:
crud_params_single = crud_params
@@ -2106,27 +2384,31 @@ class SQLCompiler(Compiled):
text = "INSERT "
if insert_stmt._prefixes:
- text += self._generate_prefixes(insert_stmt,
- insert_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ insert_stmt, insert_stmt._prefixes, **kw
+ )
text += "INTO "
table_text = preparer.format_table(insert_stmt.table)
if insert_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- insert_stmt, table_text)
+ insert_stmt, table_text
+ )
else:
dialect_hints = None
text += table_text
if crud_params_single or not supports_default_values:
- text += " (%s)" % ', '.join([preparer.format_column(c[0])
- for c in crud_params_single])
+ text += " (%s)" % ", ".join(
+ [preparer.format_column(c[0]) for c in crud_params_single]
+ )
if self.returning or insert_stmt._returning:
returning_clause = self.returning_clause(
- insert_stmt, self.returning or insert_stmt._returning)
+ insert_stmt, self.returning or insert_stmt._returning
+ )
if self.returning_precedes_values:
text += " " + returning_clause
@@ -2145,19 +2427,17 @@ class SQLCompiler(Compiled):
elif insert_stmt._has_multi_parameters:
text += " VALUES %s" % (
", ".join(
- "(%s)" % (
- ', '.join(c[1] for c in crud_param_set)
- )
+ "(%s)" % (", ".join(c[1] for c in crud_param_set))
for crud_param_set in crud_params
)
)
else:
- text += " VALUES (%s)" % \
- ', '.join([c[1] for c in crud_params])
+ text += " VALUES (%s)" % ", ".join([c[1] for c in crud_params])
if insert_stmt._post_values_clause is not None:
post_values_clause = self.process(
- insert_stmt._post_values_clause, **kw)
+ insert_stmt._post_values_clause, **kw
+ )
if post_values_clause:
text += " " + post_values_clause
@@ -2178,21 +2458,19 @@ class SQLCompiler(Compiled):
"""Provide a hook for MySQL to add LIMIT to the UPDATE"""
return None
- def update_tables_clause(self, update_stmt, from_table,
- extra_froms, **kw):
+ def update_tables_clause(self, update_stmt, from_table, extra_froms, **kw):
"""Provide a hook to override the initial table clause
in an UPDATE statement.
MySQL overrides this.
"""
- kw['asfrom'] = True
+ kw["asfrom"] = True
return from_table._compiler_dispatch(self, iscrud=True, **kw)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Provide a hook to override the generation of an
UPDATE..FROM clause.
@@ -2201,7 +2479,8 @@ class SQLCompiler(Compiled):
"""
raise NotImplementedError(
"This backend does not support multiple-table "
- "criteria within UPDATE")
+ "criteria within UPDATE"
+ )
def visit_update(self, update_stmt, asfrom=False, **kw):
toplevel = not self.stack
@@ -2221,49 +2500,61 @@ class SQLCompiler(Compiled):
correlate_froms = {update_stmt.table}
self.stack.append(
- {'correlate_froms': correlate_froms,
- "asfrom_froms": correlate_froms,
- "selectable": update_stmt})
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": update_stmt,
+ }
+ )
text = "UPDATE "
if update_stmt._prefixes:
- text += self._generate_prefixes(update_stmt,
- update_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ update_stmt, update_stmt._prefixes, **kw
+ )
- table_text = self.update_tables_clause(update_stmt, update_stmt.table,
- render_extra_froms, **kw)
+ table_text = self.update_tables_clause(
+ update_stmt, update_stmt.table, render_extra_froms, **kw
+ )
crud_params = crud._setup_crud_params(
- self, update_stmt, crud.ISUPDATE, **kw)
+ self, update_stmt, crud.ISUPDATE, **kw
+ )
if update_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- update_stmt, table_text)
+ update_stmt, table_text
+ )
else:
dialect_hints = None
text += table_text
- text += ' SET '
- include_table = is_multitable and \
- self.render_table_with_column_in_update_from
- text += ', '.join(
- c[0]._compiler_dispatch(self,
- include_table=include_table) +
- '=' + c[1] for c in crud_params
+ text += " SET "
+ include_table = (
+ is_multitable and self.render_table_with_column_in_update_from
+ )
+ text += ", ".join(
+ c[0]._compiler_dispatch(self, include_table=include_table)
+ + "="
+ + c[1]
+ for c in crud_params
)
if self.returning or update_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning or update_stmt._returning)
+ update_stmt, self.returning or update_stmt._returning
+ )
if extra_froms:
extra_from_text = self.update_from_clause(
update_stmt,
update_stmt.table,
render_extra_froms,
- dialect_hints, **kw)
+ dialect_hints,
+ **kw
+ )
if extra_from_text:
text += " " + extra_from_text
@@ -2276,10 +2567,12 @@ class SQLCompiler(Compiled):
if limit_clause:
text += " " + limit_clause
- if (self.returning or update_stmt._returning) and \
- not self.returning_precedes_values:
+ if (
+ self.returning or update_stmt._returning
+ ) and not self.returning_precedes_values:
text += " " + self.returning_clause(
- update_stmt, self.returning or update_stmt._returning)
+ update_stmt, self.returning or update_stmt._returning
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -2295,9 +2588,9 @@ class SQLCompiler(Compiled):
def _key_getters_for_crud_column(self):
return crud._key_getters_for_crud_column(self, self.statement)
- def delete_extra_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints, **kw):
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
"""Provide a hook to override the generation of an
DELETE..FROM clause.
@@ -2308,10 +2601,10 @@ class SQLCompiler(Compiled):
"""
raise NotImplementedError(
"This backend does not support multiple-table "
- "criteria within DELETE")
+ "criteria within DELETE"
+ )
- def delete_table_clause(self, delete_stmt, from_table,
- extra_froms):
+ def delete_table_clause(self, delete_stmt, from_table, extra_froms):
return from_table._compiler_dispatch(self, asfrom=True, iscrud=True)
def visit_delete(self, delete_stmt, asfrom=False, **kw):
@@ -2322,23 +2615,30 @@ class SQLCompiler(Compiled):
extra_froms = delete_stmt._extra_froms
correlate_froms = {delete_stmt.table}.union(extra_froms)
- self.stack.append({'correlate_froms': correlate_froms,
- "asfrom_froms": correlate_froms,
- "selectable": delete_stmt})
+ self.stack.append(
+ {
+ "correlate_froms": correlate_froms,
+ "asfrom_froms": correlate_froms,
+ "selectable": delete_stmt,
+ }
+ )
text = "DELETE "
if delete_stmt._prefixes:
- text += self._generate_prefixes(delete_stmt,
- delete_stmt._prefixes, **kw)
+ text += self._generate_prefixes(
+ delete_stmt, delete_stmt._prefixes, **kw
+ )
text += "FROM "
- table_text = self.delete_table_clause(delete_stmt, delete_stmt.table,
- extra_froms)
+ table_text = self.delete_table_clause(
+ delete_stmt, delete_stmt.table, extra_froms
+ )
if delete_stmt._hints:
dialect_hints, table_text = self._setup_crud_hints(
- delete_stmt, table_text)
+ delete_stmt, table_text
+ )
else:
dialect_hints = None
@@ -2347,14 +2647,17 @@ class SQLCompiler(Compiled):
if delete_stmt._returning:
if self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning
+ )
if extra_froms:
extra_from_text = self.delete_extra_from_clause(
delete_stmt,
delete_stmt.table,
extra_froms,
- dialect_hints, **kw)
+ dialect_hints,
+ **kw
+ )
if extra_from_text:
text += " " + extra_from_text
@@ -2365,7 +2668,8 @@ class SQLCompiler(Compiled):
if delete_stmt._returning and not self.returning_precedes_values:
text += " " + self.returning_clause(
- delete_stmt, delete_stmt._returning)
+ delete_stmt, delete_stmt._returning
+ )
if self.ctes and toplevel:
text = self._render_cte_clause() + text
@@ -2381,12 +2685,14 @@ class SQLCompiler(Compiled):
return "SAVEPOINT %s" % self.preparer.format_savepoint(savepoint_stmt)
def visit_rollback_to_savepoint(self, savepoint_stmt):
- return "ROLLBACK TO SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "ROLLBACK TO SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
def visit_release_savepoint(self, savepoint_stmt):
- return "RELEASE SAVEPOINT %s" % \
- self.preparer.format_savepoint(savepoint_stmt)
+ return "RELEASE SAVEPOINT %s" % self.preparer.format_savepoint(
+ savepoint_stmt
+ )
class StrSQLCompiler(SQLCompiler):
@@ -2403,7 +2709,7 @@ class StrSQLCompiler(SQLCompiler):
def visit_getitem_binary(self, binary, operator, **kw):
return "%s[%s]" % (
self.process(binary.left, **kw),
- self.process(binary.right, **kw)
+ self.process(binary.right, **kw),
)
def visit_json_getitem_op_binary(self, binary, operator, **kw):
@@ -2421,29 +2727,26 @@ class StrSQLCompiler(SQLCompiler):
for c in elements._select_iterables(returning_cols)
]
- return 'RETURNING ' + ', '.join(columns)
+ return "RETURNING " + ", ".join(columns)
- def update_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
- return "FROM " + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ def update_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return "FROM " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
- def delete_extra_from_clause(self, update_stmt,
- from_table, extra_froms,
- from_hints,
- **kw):
- return ', ' + ', '.join(
- t._compiler_dispatch(self, asfrom=True,
- fromhints=from_hints, **kw)
- for t in extra_froms)
+ def delete_extra_from_clause(
+ self, update_stmt, from_table, extra_froms, from_hints, **kw
+ ):
+ return ", " + ", ".join(
+ t._compiler_dispatch(self, asfrom=True, fromhints=from_hints, **kw)
+ for t in extra_froms
+ )
class DDLCompiler(Compiled):
-
@util.memoized_property
def sql_compiler(self):
return self.dialect.statement_compiler(self.dialect, None)
@@ -2464,13 +2767,13 @@ class DDLCompiler(Compiled):
preparer = self.preparer
path = preparer.format_table_seq(ddl.target)
if len(path) == 1:
- table, sch = path[0], ''
+ table, sch = path[0], ""
else:
table, sch = path[-1], path[0]
- context.setdefault('table', table)
- context.setdefault('schema', sch)
- context.setdefault('fullname', preparer.format_table(ddl.target))
+ context.setdefault("table", table)
+ context.setdefault("schema", sch)
+ context.setdefault("fullname", preparer.format_table(ddl.target))
return self.sql_compiler.post_process_text(ddl.statement % context)
@@ -2507,9 +2810,9 @@ class DDLCompiler(Compiled):
for create_column in create.columns:
column = create_column.element
try:
- processed = self.process(create_column,
- first_pk=column.primary_key
- and not first_pk)
+ processed = self.process(
+ create_column, first_pk=column.primary_key and not first_pk
+ )
if processed is not None:
text += separator
separator = ", \n"
@@ -2519,13 +2822,15 @@ class DDLCompiler(Compiled):
except exc.CompileError as ce:
util.raise_from_cause(
exc.CompileError(
- util.u("(in table '%s', column '%s'): %s") %
- (table.description, column.name, ce.args[0])
- ))
+ util.u("(in table '%s', column '%s'): %s")
+ % (table.description, column.name, ce.args[0])
+ )
+ )
const = self.create_table_constraints(
- table, _include_foreign_key_constraints= # noqa
- create.include_foreign_key_constraints)
+ table,
+ _include_foreign_key_constraints=create.include_foreign_key_constraints, # noqa
+ )
if const:
text += separator + "\t" + const
@@ -2538,20 +2843,18 @@ class DDLCompiler(Compiled):
if column.system:
return None
- text = self.get_column_specification(
- column,
- first_pk=first_pk
+ text = self.get_column_specification(column, first_pk=first_pk)
+ const = " ".join(
+ self.process(constraint) for constraint in column.constraints
)
- const = " ".join(self.process(constraint)
- for constraint in column.constraints)
if const:
text += " " + const
return text
def create_table_constraints(
- self, table,
- _include_foreign_key_constraints=None):
+ self, table, _include_foreign_key_constraints=None
+ ):
# On some DB order is significant: visit PK first, then the
# other constraints (engine.ReflectionTest.testbasic failed on FB2)
@@ -2565,21 +2868,29 @@ class DDLCompiler(Compiled):
else:
omit_fkcs = set()
- constraints.extend([c for c in table._sorted_constraints
- if c is not table.primary_key and
- c not in omit_fkcs])
+ constraints.extend(
+ [
+ c
+ for c in table._sorted_constraints
+ if c is not table.primary_key and c not in omit_fkcs
+ ]
+ )
return ", \n\t".join(
- p for p in
- (self.process(constraint)
+ p
+ for p in (
+ self.process(constraint)
for constraint in constraints
if (
- constraint._create_rule is None or
- constraint._create_rule(self))
+ constraint._create_rule is None
+ or constraint._create_rule(self)
+ )
and (
- not self.dialect.supports_alter or
- not getattr(constraint, 'use_alter', False)
- )) if p is not None
+ not self.dialect.supports_alter
+ or not getattr(constraint, "use_alter", False)
+ )
+ )
+ if p is not None
)
def visit_drop_table(self, drop):
@@ -2590,34 +2901,38 @@ class DDLCompiler(Compiled):
def _verify_index_table(self, index):
if index.table is None:
- raise exc.CompileError("Index '%s' is not associated "
- "with any table." % index.name)
+ raise exc.CompileError(
+ "Index '%s' is not associated " "with any table." % index.name
+ )
- def visit_create_index(self, create, include_schema=False,
- include_table_schema=True):
+ def visit_create_index(
+ self, create, include_schema=False, include_table_schema=True
+ ):
index = create.element
self._verify_index_table(index)
preparer = self.preparer
text = "CREATE "
if index.unique:
text += "UNIQUE "
- text += "INDEX %s ON %s (%s)" \
- % (
- self._prepared_index_name(index,
- include_schema=include_schema),
- preparer.format_table(index.table,
- use_schema=include_table_schema),
- ', '.join(
- self.sql_compiler.process(
- expr, include_table=False, literal_binds=True) for
- expr in index.expressions)
- )
+ text += "INDEX %s ON %s (%s)" % (
+ self._prepared_index_name(index, include_schema=include_schema),
+ preparer.format_table(
+ index.table, use_schema=include_table_schema
+ ),
+ ", ".join(
+ self.sql_compiler.process(
+ expr, include_table=False, literal_binds=True
+ )
+ for expr in index.expressions
+ ),
+ )
return text
def visit_drop_index(self, drop):
index = drop.element
return "\nDROP INDEX " + self._prepared_index_name(
- index, include_schema=True)
+ index, include_schema=True
+ )
def _prepared_index_name(self, index, include_schema=False):
if index.table is not None:
@@ -2638,35 +2953,41 @@ class DDLCompiler(Compiled):
def visit_add_constraint(self, create):
return "ALTER TABLE %s ADD %s" % (
self.preparer.format_table(create.element.table),
- self.process(create.element)
+ self.process(create.element),
)
def visit_set_table_comment(self, create):
return "COMMENT ON TABLE %s IS %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_drop_table_comment(self, drop):
- return "COMMENT ON TABLE %s IS NULL" % \
- self.preparer.format_table(drop.element)
+ return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
+ drop.element
+ )
def visit_set_column_comment(self, create):
return "COMMENT ON COLUMN %s IS %s" % (
self.preparer.format_column(
- create.element, use_table=True, use_schema=True),
+ create.element, use_table=True, use_schema=True
+ ),
self.sql_compiler.render_literal_value(
- create.element.comment, sqltypes.String())
+ create.element.comment, sqltypes.String()
+ ),
)
def visit_drop_column_comment(self, drop):
- return "COMMENT ON COLUMN %s IS NULL" % \
- self.preparer.format_column(drop.element, use_table=True)
+ return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
+ drop.element, use_table=True
+ )
def visit_create_sequence(self, create):
- text = "CREATE SEQUENCE %s" % \
- self.preparer.format_sequence(create.element)
+ text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
+ create.element
+ )
if create.element.increment is not None:
text += " INCREMENT BY %d" % create.element.increment
if create.element.start is not None:
@@ -2688,8 +3009,7 @@ class DDLCompiler(Compiled):
return text
def visit_drop_sequence(self, drop):
- return "DROP SEQUENCE %s" % \
- self.preparer.format_sequence(drop.element)
+ return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
def visit_drop_constraint(self, drop):
constraint = drop.element
@@ -2701,17 +3021,22 @@ class DDLCompiler(Compiled):
if formatted_name is None:
raise exc.CompileError(
"Can't emit DROP CONSTRAINT for constraint %r; "
- "it has no name" % drop.element)
+ "it has no name" % drop.element
+ )
return "ALTER TABLE %s DROP CONSTRAINT %s%s" % (
self.preparer.format_table(drop.element.table),
formatted_name,
- drop.cascade and " CASCADE" or ""
+ drop.cascade and " CASCADE" or "",
)
def get_column_specification(self, column, **kwargs):
- colspec = self.preparer.format_column(column) + " " + \
- self.dialect.type_compiler.process(
- column.type, type_expression=column)
+ colspec = (
+ self.preparer.format_column(column)
+ + " "
+ + self.dialect.type_compiler.process(
+ column.type, type_expression=column
+ )
+ )
default = self.get_column_default_string(column)
if default is not None:
colspec += " DEFAULT " + default
@@ -2721,19 +3046,21 @@ class DDLCompiler(Compiled):
return colspec
def create_table_suffix(self, table):
- return ''
+ return ""
def post_create_table(self, table):
- return ''
+ return ""
def get_column_default_string(self, column):
if isinstance(column.server_default, schema.DefaultClause):
if isinstance(column.server_default.arg, util.string_types):
return self.sql_compiler.render_literal_value(
- column.server_default.arg, sqltypes.STRINGTYPE)
+ column.server_default.arg, sqltypes.STRINGTYPE
+ )
else:
return self.sql_compiler.process(
- column.server_default.arg, literal_binds=True)
+ column.server_default.arg, literal_binds=True
+ )
else:
return None
@@ -2743,9 +3070,9 @@ class DDLCompiler(Compiled):
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
- text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
- include_table=False,
- literal_binds=True)
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2755,25 +3082,29 @@ class DDLCompiler(Compiled):
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
- text += "CHECK (%s)" % self.sql_compiler.process(constraint.sqltext,
- include_table=False,
- literal_binds=True)
+ text += "CHECK (%s)" % self.sql_compiler.process(
+ constraint.sqltext, include_table=False, literal_binds=True
+ )
text += self.define_constraint_deferrability(constraint)
return text
def visit_primary_key_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
if formatted_name is not None:
text += "CONSTRAINT %s " % formatted_name
text += "PRIMARY KEY "
- text += "(%s)" % ', '.join(self.preparer.quote(c.name)
- for c in (constraint.columns_autoinc_first
- if constraint._implicit_generated
- else constraint.columns))
+ text += "(%s)" % ", ".join(
+ self.preparer.quote(c.name)
+ for c in (
+ constraint.columns_autoinc_first
+ if constraint._implicit_generated
+ else constraint.columns
+ )
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2786,12 +3117,15 @@ class DDLCompiler(Compiled):
text += "CONSTRAINT %s " % formatted_name
remote_table = list(constraint.elements)[0].column.table
text += "FOREIGN KEY(%s) REFERENCES %s (%s)" % (
- ', '.join(preparer.quote(f.parent.name)
- for f in constraint.elements),
+ ", ".join(
+ preparer.quote(f.parent.name) for f in constraint.elements
+ ),
self.define_constraint_remote_table(
- constraint, remote_table, preparer),
- ', '.join(preparer.quote(f.column.name)
- for f in constraint.elements)
+ constraint, remote_table, preparer
+ ),
+ ", ".join(
+ preparer.quote(f.column.name) for f in constraint.elements
+ ),
)
text += self.define_constraint_match(constraint)
text += self.define_constraint_cascades(constraint)
@@ -2805,14 +3139,14 @@ class DDLCompiler(Compiled):
def visit_unique_constraint(self, constraint):
if len(constraint) == 0:
- return ''
+ return ""
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
text += "CONSTRAINT %s " % formatted_name
text += "UNIQUE (%s)" % (
- ', '.join(self.preparer.quote(c.name)
- for c in constraint))
+ ", ".join(self.preparer.quote(c.name) for c in constraint)
+ )
text += self.define_constraint_deferrability(constraint)
return text
@@ -2843,7 +3177,6 @@ class DDLCompiler(Compiled):
class GenericTypeCompiler(TypeCompiler):
-
def visit_FLOAT(self, type_, **kw):
return "FLOAT"
@@ -2854,23 +3187,23 @@ class GenericTypeCompiler(TypeCompiler):
if type_.precision is None:
return "NUMERIC"
elif type_.scale is None:
- return "NUMERIC(%(precision)s)" % \
- {'precision': type_.precision}
+ return "NUMERIC(%(precision)s)" % {"precision": type_.precision}
else:
- return "NUMERIC(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ return "NUMERIC(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
def visit_DECIMAL(self, type_, **kw):
if type_.precision is None:
return "DECIMAL"
elif type_.scale is None:
- return "DECIMAL(%(precision)s)" % \
- {'precision': type_.precision}
+ return "DECIMAL(%(precision)s)" % {"precision": type_.precision}
else:
- return "DECIMAL(%(precision)s, %(scale)s)" % \
- {'precision': type_.precision,
- 'scale': type_.scale}
+ return "DECIMAL(%(precision)s, %(scale)s)" % {
+ "precision": type_.precision,
+ "scale": type_.scale,
+ }
def visit_INTEGER(self, type_, **kw):
return "INTEGER"
@@ -2882,7 +3215,7 @@ class GenericTypeCompiler(TypeCompiler):
return "BIGINT"
def visit_TIMESTAMP(self, type_, **kw):
- return 'TIMESTAMP'
+ return "TIMESTAMP"
def visit_DATETIME(self, type_, **kw):
return "DATETIME"
@@ -2984,9 +3317,11 @@ class GenericTypeCompiler(TypeCompiler):
return self.visit_VARCHAR(type_, **kw)
def visit_null(self, type_, **kw):
- raise exc.CompileError("Can't generate DDL for %r; "
- "did you forget to specify a "
- "type on this Column?" % type_)
+ raise exc.CompileError(
+ "Can't generate DDL for %r; "
+ "did you forget to specify a "
+ "type on this Column?" % type_
+ )
def visit_type_decorator(self, type_, **kw):
return self.process(type_.type_engine(self.dialect), **kw)
@@ -3018,9 +3353,15 @@ class IdentifierPreparer(object):
schema_for_object = schema._schema_getter(None)
- def __init__(self, dialect, initial_quote='"',
- final_quote=None, escape_quote='"',
- quote_case_sensitive_collations=True, omit_schema=False):
+ def __init__(
+ self,
+ dialect,
+ initial_quote='"',
+ final_quote=None,
+ escape_quote='"',
+ quote_case_sensitive_collations=True,
+ omit_schema=False,
+ ):
"""Construct a new ``IdentifierPreparer`` object.
initial_quote
@@ -3043,7 +3384,10 @@ class IdentifierPreparer(object):
self.omit_schema = omit_schema
self.quote_case_sensitive_collations = quote_case_sensitive_collations
self._strings = {}
- self._double_percents = self.dialect.paramstyle in ('format', 'pyformat')
+ self._double_percents = self.dialect.paramstyle in (
+ "format",
+ "pyformat",
+ )
def _with_schema_translate(self, schema_translate_map):
prep = self.__class__.__new__(self.__class__)
@@ -3060,7 +3404,7 @@ class IdentifierPreparer(object):
value = value.replace(self.escape_quote, self.escape_to_quote)
if self._double_percents:
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
return value
def _unescape_identifier(self, value):
@@ -3079,17 +3423,21 @@ class IdentifierPreparer(object):
quoting behavior.
"""
- return self.initial_quote + \
- self._escape_identifier(value) + \
- self.final_quote
+ return (
+ self.initial_quote
+ + self._escape_identifier(value)
+ + self.final_quote
+ )
def _requires_quotes(self, value):
"""Return True if the given identifier requires quoting."""
lc_value = value.lower()
- return (lc_value in self.reserved_words
- or value[0] in self.illegal_initial_characters
- or not self.legal_characters.match(util.text_type(value))
- or (lc_value != value))
+ return (
+ lc_value in self.reserved_words
+ or value[0] in self.illegal_initial_characters
+ or not self.legal_characters.match(util.text_type(value))
+ or (lc_value != value)
+ )
def quote_schema(self, schema, force=None):
"""Conditionally quote a schema.
@@ -3135,8 +3483,11 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(sequence)
- if (not self.omit_schema and use_schema and
- effective_schema is not None):
+ if (
+ not self.omit_schema
+ and use_schema
+ and effective_schema is not None
+ ):
name = self.quote_schema(effective_schema) + "." + name
return name
@@ -3159,7 +3510,8 @@ class IdentifierPreparer(object):
def format_constraint(self, naming, constraint):
if isinstance(constraint.name, elements._defer_name):
name = naming._constraint_name_for_table(
- constraint, constraint.table)
+ constraint, constraint.table
+ )
if name is None:
if isinstance(constraint.name, elements._defer_none_name):
@@ -3170,14 +3522,15 @@ class IdentifierPreparer(object):
name = constraint.name
if isinstance(name, elements._truncated_label):
- if constraint.__visit_name__ == 'index':
- max_ = self.dialect.max_index_name_length or \
- self.dialect.max_identifier_length
+ if constraint.__visit_name__ == "index":
+ max_ = (
+ self.dialect.max_index_name_length
+ or self.dialect.max_identifier_length
+ )
else:
max_ = self.dialect.max_identifier_length
if len(name) > max_:
- name = name[0:max_ - 8] + \
- "_" + util.md5_hex(name)[-4:]
+ name = name[0 : max_ - 8] + "_" + util.md5_hex(name)[-4:]
else:
self.dialect.validate_identifier(name)
@@ -3195,8 +3548,7 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(table)
- if not self.omit_schema and use_schema \
- and effective_schema:
+ if not self.omit_schema and use_schema and effective_schema:
result = self.quote_schema(effective_schema) + "." + result
return result
@@ -3205,17 +3557,27 @@ class IdentifierPreparer(object):
return self.quote(name, quote)
- def format_column(self, column, use_table=False,
- name=None, table_name=None, use_schema=False):
+ def format_column(
+ self,
+ column,
+ use_table=False,
+ name=None,
+ table_name=None,
+ use_schema=False,
+ ):
"""Prepare a quoted column name."""
if name is None:
name = column.name
- if not getattr(column, 'is_literal', False):
+ if not getattr(column, "is_literal", False):
if use_table:
- return self.format_table(
- column.table, use_schema=use_schema,
- name=table_name) + "." + self.quote(name)
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + self.quote(name)
+ )
else:
return self.quote(name)
else:
@@ -3223,9 +3585,13 @@ class IdentifierPreparer(object):
# which shouldn't get quoted
if use_table:
- return self.format_table(
- column.table, use_schema=use_schema,
- name=table_name) + '.' + name
+ return (
+ self.format_table(
+ column.table, use_schema=use_schema, name=table_name
+ )
+ + "."
+ + name
+ )
else:
return name
@@ -3238,31 +3604,37 @@ class IdentifierPreparer(object):
effective_schema = self.schema_for_object(table)
- if not self.omit_schema and use_schema and \
- effective_schema:
- return (self.quote_schema(effective_schema),
- self.format_table(table, use_schema=False))
+ if not self.omit_schema and use_schema and effective_schema:
+ return (
+ self.quote_schema(effective_schema),
+ self.format_table(table, use_schema=False),
+ )
else:
- return (self.format_table(table, use_schema=False), )
+ return (self.format_table(table, use_schema=False),)
@util.memoized_property
def _r_identifiers(self):
- initial, final, escaped_final = \
- [re.escape(s) for s in
- (self.initial_quote, self.final_quote,
- self._escape_identifier(self.final_quote))]
+ initial, final, escaped_final = [
+ re.escape(s)
+ for s in (
+ self.initial_quote,
+ self.final_quote,
+ self._escape_identifier(self.final_quote),
+ )
+ ]
r = re.compile(
- r'(?:'
- r'(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s'
- r'|([^\.]+))(?=\.|$))+' %
- {'initial': initial,
- 'final': final,
- 'escaped': escaped_final})
+ r"(?:"
+ r"(?:%(initial)s((?:%(escaped)s|[^%(final)s])+)%(final)s"
+ r"|([^\.]+))(?=\.|$))+"
+ % {"initial": initial, "final": final, "escaped": escaped_final}
+ )
return r
def unformat_identifiers(self, identifiers):
"""Unpack 'schema.table.column'-like strings into components."""
r = self._r_identifiers
- return [self._unescape_identifier(i)
- for i in [a or b for a, b in r.findall(identifiers)]]
+ return [
+ self._unescape_identifier(i)
+ for i in [a or b for a, b in r.findall(identifiers)]
+ ]
diff --git a/lib/sqlalchemy/sql/crud.py b/lib/sqlalchemy/sql/crud.py
index 999d48a55..602b91a25 100644
--- a/lib/sqlalchemy/sql/crud.py
+++ b/lib/sqlalchemy/sql/crud.py
@@ -15,7 +15,9 @@ from . import dml
from . import elements
import operator
-REQUIRED = util.symbol('REQUIRED', """
+REQUIRED = util.symbol(
+ "REQUIRED",
+ """
Placeholder for the value within a :class:`.BindParameter`
which is required to be present when the statement is passed
to :meth:`.Connection.execute`.
@@ -24,11 +26,12 @@ This symbol is typically used when a :func:`.expression.insert`
or :func:`.expression.update` statement is compiled without parameter
values present.
-""")
+""",
+)
-ISINSERT = util.symbol('ISINSERT')
-ISUPDATE = util.symbol('ISUPDATE')
-ISDELETE = util.symbol('ISDELETE')
+ISINSERT = util.symbol("ISINSERT")
+ISUPDATE = util.symbol("ISUPDATE")
+ISDELETE = util.symbol("ISDELETE")
def _setup_crud_params(compiler, stmt, local_stmt_type, **kw):
@@ -82,8 +85,7 @@ def _get_crud_params(compiler, stmt, **kw):
# compiled params - return binds for all columns
if compiler.column_keys is None and stmt.parameters is None:
return [
- (c, _create_bind_param(
- compiler, c, None, required=True))
+ (c, _create_bind_param(compiler, c, None, required=True))
for c in stmt.table.columns
]
@@ -95,26 +97,28 @@ def _get_crud_params(compiler, stmt, **kw):
# getters - these are normally just column.key,
# but in the case of mysql multi-table update, the rules for
# .key must conditionally take tablename into account
- _column_as_key, _getattr_col_key, _col_bind_name = \
- _key_getters_for_crud_column(compiler, stmt)
+ _column_as_key, _getattr_col_key, _col_bind_name = _key_getters_for_crud_column(
+ compiler, stmt
+ )
# if we have statement parameters - set defaults in the
# compiled params
if compiler.column_keys is None:
parameters = {}
else:
- parameters = dict((_column_as_key(key), REQUIRED)
- for key in compiler.column_keys
- if not stmt_parameters or
- key not in stmt_parameters)
+ parameters = dict(
+ (_column_as_key(key), REQUIRED)
+ for key in compiler.column_keys
+ if not stmt_parameters or key not in stmt_parameters
+ )
# create a list of column assignment clauses as tuples
values = []
if stmt_parameters is not None:
_get_stmt_parameters_params(
- compiler,
- parameters, stmt_parameters, _column_as_key, values, kw)
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw
+ )
check_columns = {}
@@ -122,28 +126,51 @@ def _get_crud_params(compiler, stmt, **kw):
# statements
if compiler.isupdate and stmt._extra_froms and stmt_parameters:
_get_multitable_params(
- compiler, stmt, stmt_parameters, check_columns,
- _col_bind_name, _getattr_col_key, values, kw)
+ compiler,
+ stmt,
+ stmt_parameters,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+ )
if compiler.isinsert and stmt.select_names:
_scan_insert_from_select_cols(
- compiler, stmt, parameters,
- _getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
else:
_scan_cols(
- compiler, stmt, parameters,
- _getattr_col_key, _column_as_key,
- _col_bind_name, check_columns, values, kw)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+ )
if parameters and stmt_parameters:
- check = set(parameters).intersection(
- _column_as_key(k) for k in stmt_parameters
- ).difference(check_columns)
+ check = (
+ set(parameters)
+ .intersection(_column_as_key(k) for k in stmt_parameters)
+ .difference(check_columns)
+ )
if check:
raise exc.CompileError(
- "Unconsumed column names: %s" %
- (", ".join("%s" % c for c in check))
+ "Unconsumed column names: %s"
+ % (", ".join("%s" % c for c in check))
)
if stmt._has_multi_parameters:
@@ -153,12 +180,13 @@ def _get_crud_params(compiler, stmt, **kw):
def _create_bind_param(
- compiler, col, value, process=True,
- required=False, name=None, **kw):
+ compiler, col, value, process=True, required=False, name=None, **kw
+):
if name is None:
name = col.key
bindparam = elements.BindParameter(
- name, value, type_=col.type, required=required)
+ name, value, type_=col.type, required=required
+ )
bindparam._is_crud = True
if process:
bindparam = bindparam._compiler_dispatch(compiler, **kw)
@@ -177,7 +205,7 @@ def _key_getters_for_crud_column(compiler, stmt):
def _column_as_key(key):
str_key = elements._column_as_key(key)
- if hasattr(key, 'table') and key.table in _et:
+ if hasattr(key, "table") and key.table in _et:
return (key.table.name, str_key)
else:
return str_key
@@ -202,15 +230,22 @@ def _key_getters_for_crud_column(compiler, stmt):
def _scan_insert_from_select_cols(
- compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
-
- need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid = \
- _get_returning_modifiers(compiler, stmt)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers(
+ compiler, stmt
+ )
- cols = [stmt.table.c[_column_as_key(name)]
- for name in stmt.select_names]
+ cols = [stmt.table.c[_column_as_key(name)] for name in stmt.select_names]
compiler._insert_from_select = stmt.select
@@ -228,32 +263,39 @@ def _scan_insert_from_select_cols(
values.append((c, None))
else:
_append_param_insert_select_hasdefault(
- compiler, stmt, c, add_select_cols, kw)
+ compiler, stmt, c, add_select_cols, kw
+ )
if add_select_cols:
values.extend(add_select_cols)
compiler._insert_from_select = compiler._insert_from_select._generate()
- compiler._insert_from_select._raw_columns = \
- tuple(compiler._insert_from_select._raw_columns) + tuple(
- expr for col, expr in add_select_cols)
+ compiler._insert_from_select._raw_columns = tuple(
+ compiler._insert_from_select._raw_columns
+ ) + tuple(expr for col, expr in add_select_cols)
def _scan_cols(
- compiler, stmt, parameters, _getattr_col_key,
- _column_as_key, _col_bind_name, check_columns, values, kw):
-
- need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid = \
- _get_returning_modifiers(compiler, stmt)
+ compiler,
+ stmt,
+ parameters,
+ _getattr_col_key,
+ _column_as_key,
+ _col_bind_name,
+ check_columns,
+ values,
+ kw,
+):
+
+ need_pks, implicit_returning, implicit_return_defaults, postfetch_lastrowid = _get_returning_modifiers(
+ compiler, stmt
+ )
if stmt._parameter_ordering:
parameter_ordering = [
_column_as_key(key) for key in stmt._parameter_ordering
]
ordered_keys = set(parameter_ordering)
- cols = [
- stmt.table.c[key] for key in parameter_ordering
- ] + [
+ cols = [stmt.table.c[key] for key in parameter_ordering] + [
c for c in stmt.table.c if c.key not in ordered_keys
]
else:
@@ -265,72 +307,95 @@ def _scan_cols(
if col_key in parameters and col_key not in check_columns:
_append_param_parameter(
- compiler, stmt, c, col_key, parameters, _col_bind_name,
- implicit_returning, implicit_return_defaults, values, kw)
+ compiler,
+ stmt,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+ )
elif compiler.isinsert:
- if c.primary_key and \
- need_pks and \
- (
- implicit_returning or
- not postfetch_lastrowid or
- c is not stmt.table._autoincrement_column
- ):
+ if (
+ c.primary_key
+ and need_pks
+ and (
+ implicit_returning
+ or not postfetch_lastrowid
+ or c is not stmt.table._autoincrement_column
+ )
+ ):
if implicit_returning:
_append_param_insert_pk_returning(
- compiler, stmt, c, values, kw)
+ compiler, stmt, c, values, kw
+ )
else:
_append_param_insert_pk(compiler, stmt, c, values, kw)
elif c.default is not None:
_append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults,
- values, kw)
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
elif c.server_default is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
+ elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
- elif c.primary_key and \
- c is not stmt.table._autoincrement_column and \
- not c.nullable:
+ elif (
+ c.primary_key
+ and c is not stmt.table._autoincrement_column
+ and not c.nullable
+ ):
_warn_pk_with_no_anticipated_value(c)
elif compiler.isupdate:
_append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw)
+ compiler, stmt, c, implicit_return_defaults, values, kw
+ )
def _append_param_parameter(
- compiler, stmt, c, col_key, parameters, _col_bind_name,
- implicit_returning, implicit_return_defaults, values, kw):
+ compiler,
+ stmt,
+ c,
+ col_key,
+ parameters,
+ _col_bind_name,
+ implicit_returning,
+ implicit_return_defaults,
+ values,
+ kw,
+):
value = parameters.pop(col_key)
if elements._is_literal(value):
value = _create_bind_param(
- compiler, c, value, required=value is REQUIRED,
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
name=_col_bind_name(c)
if not stmt._has_multi_parameters
else "%s_m0" % _col_bind_name(c),
**kw
)
else:
- if isinstance(value, elements.BindParameter) and \
- value.type._isnull:
+ if isinstance(value, elements.BindParameter) and value.type._isnull:
value = value._clone()
value.type = c.type
if c.primary_key and implicit_returning:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
- elif implicit_return_defaults and \
- c in implicit_return_defaults:
+ elif implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
value = compiler.process(value.self_group(), **kw)
else:
@@ -358,22 +423,20 @@ def _append_param_insert_pk_returning(compiler, stmt, c, values, kw):
"""
if c.default is not None:
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional
+ or not compiler.dialect.sequences_optional
+ ):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
compiler.returning.append(c)
elif c.default.is_clause_element:
values.append(
- (c, compiler.process(
- c.default.arg.self_group(), **kw))
+ (c, compiler.process(c.default.arg.self_group(), **kw))
)
compiler.returning.append(c)
else:
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c is stmt.table._autoincrement_column or c.server_default is not None:
compiler.returning.append(c)
elif not c.nullable:
@@ -405,9 +468,11 @@ class _multiparam_column(elements.ColumnElement):
self.type = original.type
def __eq__(self, other):
- return isinstance(other, _multiparam_column) and \
- other.key == self.key and \
- other.original == self.original
+ return (
+ isinstance(other, _multiparam_column)
+ and other.key == self.key
+ and other.original == self.original
+ )
def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
@@ -416,7 +481,8 @@ def _process_multiparam_default_bind(compiler, stmt, c, index, kw):
raise exc.CompileError(
"INSERT value for column %s is explicitly rendered as a bound"
"parameter in the VALUES clause; "
- "a Python-side value or SQL expression is required" % c)
+ "a Python-side value or SQL expression is required" % c
+ )
elif c.default.is_clause_element:
return compiler.process(c.default.arg.self_group(), **kw)
else:
@@ -440,30 +506,24 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
"""
if (
- (
- # column has a Python-side default
- c.default is not None and
- (
- # and it won't be a Sequence
- not c.default.is_sequence or
- compiler.dialect.supports_sequences
- )
- )
- or
- (
- # column is the "autoincrement column"
- c is stmt.table._autoincrement_column and
- (
- # and it's either a "sequence" or a
- # pre-executable "autoincrement" sequence
- compiler.dialect.supports_sequences or
- compiler.dialect.preexecute_autoincrement_sequences
- )
- )
- ):
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
+ # column has a Python-side default
+ c.default is not None
+ and (
+ # and it won't be a Sequence
+ not c.default.is_sequence
+ or compiler.dialect.supports_sequences
)
+ ) or (
+ # column is the "autoincrement column"
+ c is stmt.table._autoincrement_column
+ and (
+ # and it's either a "sequence" or a
+ # pre-executable "autoincrement" sequence
+ compiler.dialect.supports_sequences
+ or compiler.dialect.preexecute_autoincrement_sequences
+ )
+ ):
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
elif c.default is None and c.server_default is None and not c.nullable:
# no .default, no .server_default, not autoincrement, we have
# no indication this primary key column will have any value
@@ -471,16 +531,16 @@ def _append_param_insert_pk(compiler, stmt, c, values, kw):
def _append_param_insert_hasdefault(
- compiler, stmt, c, implicit_return_defaults, values, kw):
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
proc = compiler.process(c.default, **kw)
values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
compiler.postfetch.append(c)
@@ -488,25 +548,21 @@ def _append_param_insert_hasdefault(
proc = compiler.process(c.default.arg.self_group(), **kw)
values.append((c, proc))
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
elif not c.primary_key:
# don't add primary key column to postfetch
compiler.postfetch.append(c)
else:
- values.append(
- (c, _create_insert_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_insert_prefetch_bind_param(compiler, c)))
-def _append_param_insert_select_hasdefault(
- compiler, stmt, c, values, kw):
+def _append_param_insert_select_hasdefault(compiler, stmt, c, values, kw):
if c.default.is_sequence:
- if compiler.dialect.supports_sequences and \
- (not c.default.optional or
- not compiler.dialect.sequences_optional):
+ if compiler.dialect.supports_sequences and (
+ not c.default.optional or not compiler.dialect.sequences_optional
+ ):
proc = c.default
values.append((c, proc.next_value()))
elif c.default.is_clause_element:
@@ -519,38 +575,43 @@ def _append_param_insert_select_hasdefault(
def _append_param_update(
- compiler, stmt, c, implicit_return_defaults, values, kw):
+ compiler, stmt, c, implicit_return_defaults, values, kw
+):
if c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, compiler.process(
- c.onupdate.arg.self_group(), **kw))
+ (c, compiler.process(c.onupdate.arg.self_group(), **kw))
)
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
else:
- values.append(
- (c, _create_update_prefetch_bind_param(compiler, c))
- )
+ values.append((c, _create_update_prefetch_bind_param(compiler, c)))
elif c.server_onupdate is not None:
- if implicit_return_defaults and \
- c in implicit_return_defaults:
+ if implicit_return_defaults and c in implicit_return_defaults:
compiler.returning.append(c)
else:
compiler.postfetch.append(c)
- elif implicit_return_defaults and \
- stmt._return_defaults is not True and \
- c in implicit_return_defaults:
+ elif (
+ implicit_return_defaults
+ and stmt._return_defaults is not True
+ and c in implicit_return_defaults
+ ):
compiler.returning.append(c)
def _get_multitable_params(
- compiler, stmt, stmt_parameters, check_columns,
- _col_bind_name, _getattr_col_key, values, kw):
+ compiler,
+ stmt,
+ stmt_parameters,
+ check_columns,
+ _col_bind_name,
+ _getattr_col_key,
+ values,
+ kw,
+):
normalized_params = dict(
(elements._clause_element_as_expr(c), param)
@@ -565,8 +626,12 @@ def _get_multitable_params(
value = normalized_params[c]
if elements._is_literal(value):
value = _create_bind_param(
- compiler, c, value, required=value is REQUIRED,
- name=_col_bind_name(c))
+ compiler,
+ c,
+ value,
+ required=value is REQUIRED,
+ name=_col_bind_name(c),
+ )
else:
compiler.postfetch.append(c)
value = compiler.process(value.self_group(), **kw)
@@ -577,20 +642,25 @@ def _get_multitable_params(
for c in t.c:
if c in normalized_params:
continue
- elif (c.onupdate is not None and not
- c.onupdate.is_sequence):
+ elif c.onupdate is not None and not c.onupdate.is_sequence:
if c.onupdate.is_clause_element:
values.append(
- (c, compiler.process(
- c.onupdate.arg.self_group(),
- **kw)
- )
+ (
+ c,
+ compiler.process(
+ c.onupdate.arg.self_group(), **kw
+ ),
+ )
)
compiler.postfetch.append(c)
else:
values.append(
- (c, _create_update_prefetch_bind_param(
- compiler, c, name=_col_bind_name(c)))
+ (
+ c,
+ _create_update_prefetch_bind_param(
+ compiler, c, name=_col_bind_name(c)
+ ),
+ )
)
elif c.server_onupdate is not None:
compiler.postfetch.append(c)
@@ -608,8 +678,11 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
if elements._is_literal(row[key]):
new_param = _create_bind_param(
- compiler, col, row[key],
- name="%s_m%d" % (col.key, i + 1), **kw
+ compiler,
+ col,
+ row[key],
+ name="%s_m%d" % (col.key, i + 1),
+ **kw
)
else:
new_param = compiler.process(row[key].self_group(), **kw)
@@ -626,7 +699,8 @@ def _extend_values_for_multiparams(compiler, stmt, values, kw):
def _get_stmt_parameters_params(
- compiler, parameters, stmt_parameters, _column_as_key, values, kw):
+ compiler, parameters, stmt_parameters, _column_as_key, values, kw
+):
for k, v in stmt_parameters.items():
colkey = _column_as_key(k)
if colkey is not None:
@@ -637,8 +711,8 @@ def _get_stmt_parameters_params(
# coercing right side to bound param
if elements._is_literal(v):
v = compiler.process(
- elements.BindParameter(None, v, type_=k.type),
- **kw)
+ elements.BindParameter(None, v, type_=k.type), **kw
+ )
else:
v = compiler.process(v.self_group(), **kw)
@@ -646,22 +720,27 @@ def _get_stmt_parameters_params(
def _get_returning_modifiers(compiler, stmt):
- need_pks = compiler.isinsert and \
- not compiler.inline and \
- not stmt._returning and \
- not stmt._has_multi_parameters
+ need_pks = (
+ compiler.isinsert
+ and not compiler.inline
+ and not stmt._returning
+ and not stmt._has_multi_parameters
+ )
- implicit_returning = need_pks and \
- compiler.dialect.implicit_returning and \
- stmt.table.implicit_returning
+ implicit_returning = (
+ need_pks
+ and compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ )
if compiler.isinsert:
- implicit_return_defaults = (implicit_returning and
- stmt._return_defaults)
+ implicit_return_defaults = implicit_returning and stmt._return_defaults
elif compiler.isupdate:
- implicit_return_defaults = (compiler.dialect.implicit_returning and
- stmt.table.implicit_returning and
- stmt._return_defaults)
+ implicit_return_defaults = (
+ compiler.dialect.implicit_returning
+ and stmt.table.implicit_returning
+ and stmt._return_defaults
+ )
else:
# this line is unused, currently we are always
# isinsert or isupdate
@@ -675,8 +754,12 @@ def _get_returning_modifiers(compiler, stmt):
postfetch_lastrowid = need_pks and compiler.dialect.postfetch_lastrowid
- return need_pks, implicit_returning, \
- implicit_return_defaults, postfetch_lastrowid
+ return (
+ need_pks,
+ implicit_returning,
+ implicit_return_defaults,
+ postfetch_lastrowid,
+ )
def _warn_pk_with_no_anticipated_value(c):
@@ -687,8 +770,8 @@ def _warn_pk_with_no_anticipated_value(c):
"nor does it indicate 'autoincrement=True' or 'nullable=True', "
"and no explicit value is passed. "
"Primary key columns typically may not store NULL."
- %
- (c.table.fullname, c.name, c.table.fullname))
+ % (c.table.fullname, c.name, c.table.fullname)
+ )
if len(c.table.primary_key) > 1:
msg += (
" Note that as of SQLAlchemy 1.1, 'autoincrement=True' must be "
@@ -696,5 +779,6 @@ def _warn_pk_with_no_anticipated_value(c):
"keys if AUTO_INCREMENT/SERIAL/IDENTITY "
"behavior is expected for one of the columns in the primary key. "
"CREATE TABLE statements are impacted by this change as well on "
- "most backends.")
+ "most backends."
+ )
util.warn(msg)
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 91e93efe7..f21b3d7f0 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -56,8 +56,9 @@ class DDLElement(Executable, _DDLCompiles):
"""
- _execution_options = Executable.\
- _execution_options.union({'autocommit': True})
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
target = None
on = None
@@ -95,11 +96,13 @@ class DDLElement(Executable, _DDLCompiles):
if self._should_execute(target, bind):
return bind.execute(self.against(target))
else:
- bind.engine.logger.info(
- "DDL execution skipped, criteria not met.")
+ bind.engine.logger.info("DDL execution skipped, criteria not met.")
- @util.deprecated("0.7", "See :class:`.DDLEvents`, as well as "
- ":meth:`.DDLElement.execute_if`.")
+ @util.deprecated(
+ "0.7",
+ "See :class:`.DDLEvents`, as well as "
+ ":meth:`.DDLElement.execute_if`.",
+ )
def execute_at(self, event_name, target):
"""Link execution of this DDL to the DDL lifecycle of a SchemaItem.
@@ -129,11 +132,12 @@ class DDLElement(Executable, _DDLCompiles):
"""
def call_event(target, connection, **kw):
- if self._should_execute_deprecated(event_name,
- target, connection, **kw):
+ if self._should_execute_deprecated(
+ event_name, target, connection, **kw
+ ):
return connection.execute(self.against(target))
- event.listen(target, "" + event_name.replace('-', '_'), call_event)
+ event.listen(target, "" + event_name.replace("-", "_"), call_event)
@_generative
def against(self, target):
@@ -211,8 +215,9 @@ class DDLElement(Executable, _DDLCompiles):
self.state = state
def _should_execute(self, target, bind, **kw):
- if self.on is not None and \
- not self._should_execute_deprecated(None, target, bind, **kw):
+ if self.on is not None and not self._should_execute_deprecated(
+ None, target, bind, **kw
+ ):
return False
if isinstance(self.dialect, util.string_types):
@@ -221,9 +226,9 @@ class DDLElement(Executable, _DDLCompiles):
elif isinstance(self.dialect, (tuple, list, set)):
if bind.engine.name not in self.dialect:
return False
- if (self.callable_ is not None and
- not self.callable_(self, target, bind,
- state=self.state, **kw)):
+ if self.callable_ is not None and not self.callable_(
+ self, target, bind, state=self.state, **kw
+ ):
return False
return True
@@ -245,13 +250,15 @@ class DDLElement(Executable, _DDLCompiles):
return bind.execute(self.against(target))
def _check_ddl_on(self, on):
- if (on is not None and
- (not isinstance(on, util.string_types + (tuple, list, set)) and
- not util.callable(on))):
+ if on is not None and (
+ not isinstance(on, util.string_types + (tuple, list, set))
+ and not util.callable(on)
+ ):
raise exc.ArgumentError(
"Expected the name of a database dialect, a tuple "
"of names, or a callable for "
- "'on' criteria, got type '%s'." % type(on).__name__)
+ "'on' criteria, got type '%s'." % type(on).__name__
+ )
def bind(self):
if self._bind:
@@ -259,6 +266,7 @@ class DDLElement(Executable, _DDLCompiles):
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
def _generate(self):
@@ -375,8 +383,9 @@ class DDL(DDLElement):
if not isinstance(statement, util.string_types):
raise exc.ArgumentError(
- "Expected a string or unicode SQL statement, got '%r'" %
- statement)
+ "Expected a string or unicode SQL statement, got '%r'"
+ % statement
+ )
self.statement = statement
self.context = context or {}
@@ -386,12 +395,18 @@ class DDL(DDLElement):
self._bind = bind
def __repr__(self):
- return '<%s@%s; %s>' % (
- type(self).__name__, id(self),
- ', '.join([repr(self.statement)] +
- ['%s=%r' % (key, getattr(self, key))
- for key in ('on', 'context')
- if getattr(self, key)]))
+ return "<%s@%s; %s>" % (
+ type(self).__name__,
+ id(self),
+ ", ".join(
+ [repr(self.statement)]
+ + [
+ "%s=%r" % (key, getattr(self, key))
+ for key in ("on", "context")
+ if getattr(self, key)
+ ]
+ ),
+ )
class _CreateDropBase(DDLElement):
@@ -464,8 +479,8 @@ class CreateTable(_CreateDropBase):
__visit_name__ = "create_table"
def __init__(
- self, element, on=None, bind=None,
- include_foreign_key_constraints=None):
+ self, element, on=None, bind=None, include_foreign_key_constraints=None
+ ):
"""Create a :class:`.CreateTable` construct.
:param element: a :class:`.Table` that's the subject
@@ -481,9 +496,7 @@ class CreateTable(_CreateDropBase):
"""
super(CreateTable, self).__init__(element, on=on, bind=bind)
- self.columns = [CreateColumn(column)
- for column in element.columns
- ]
+ self.columns = [CreateColumn(column) for column in element.columns]
self.include_foreign_key_constraints = include_foreign_key_constraints
@@ -494,6 +507,7 @@ class _DropView(_CreateDropBase):
This object will eventually be part of a public "view" API.
"""
+
__visit_name__ = "drop_view"
@@ -602,7 +616,8 @@ class CreateColumn(_DDLCompiles):
to support custom column creation styles.
"""
- __visit_name__ = 'create_column'
+
+ __visit_name__ = "create_column"
def __init__(self, element):
self.element = element
@@ -646,7 +661,8 @@ class AddConstraint(_CreateDropBase):
def __init__(self, element, *args, **kw):
super(AddConstraint, self).__init__(element, *args, **kw)
element._create_rule = util.portable_instancemethod(
- self._create_rule_disable)
+ self._create_rule_disable
+ )
class DropConstraint(_CreateDropBase):
@@ -658,7 +674,8 @@ class DropConstraint(_CreateDropBase):
self.cascade = cascade
super(DropConstraint, self).__init__(element, **kw)
element._create_rule = util.portable_instancemethod(
- self._create_rule_disable)
+ self._create_rule_disable
+ )
class SetTableComment(_CreateDropBase):
@@ -691,9 +708,9 @@ class DDLBase(SchemaVisitor):
class SchemaGenerator(DDLBase):
-
- def __init__(self, dialect, connection, checkfirst=False,
- tables=None, **kwargs):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
super(SchemaGenerator, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
@@ -706,25 +723,22 @@ class SchemaGenerator(DDLBase):
effective_schema = self.connection.schema_for_object(table)
if effective_schema:
self.dialect.validate_identifier(effective_schema)
- return not self.checkfirst or \
- not self.dialect.has_table(self.connection,
- table.name, schema=effective_schema)
+ return not self.checkfirst or not self.dialect.has_table(
+ self.connection, table.name, schema=effective_schema
+ )
def _can_create_sequence(self, sequence):
effective_schema = self.connection.schema_for_object(sequence)
- return self.dialect.supports_sequences and \
- (
- (not self.dialect.sequences_optional or
- not sequence.optional) and
- (
- not self.checkfirst or
- not self.dialect.has_sequence(
- self.connection,
- sequence.name,
- schema=effective_schema)
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or not self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
)
)
+ )
def visit_metadata(self, metadata):
if self.tables is not None:
@@ -733,18 +747,23 @@ class SchemaGenerator(DDLBase):
tables = list(metadata.tables.values())
collection = sort_tables_and_constraints(
- [t for t in tables if self._can_create_table(t)])
-
- seq_coll = [s for s in metadata._sequences.values()
- if s.column is None and self._can_create_sequence(s)]
+ [t for t in tables if self._can_create_table(t)]
+ )
- event_collection = [
- t for (t, fks) in collection if t is not None
+ seq_coll = [
+ s
+ for s in metadata._sequences.values()
+ if s.column is None and self._can_create_sequence(s)
]
- metadata.dispatch.before_create(metadata, self.connection,
- tables=event_collection,
- checkfirst=self.checkfirst,
- _ddl_runner=self)
+
+ event_collection = [t for (t, fks) in collection if t is not None]
+ metadata.dispatch.before_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
for seq in seq_coll:
self.traverse_single(seq, create_ok=True)
@@ -752,30 +771,40 @@ class SchemaGenerator(DDLBase):
for table, fkcs in collection:
if table is not None:
self.traverse_single(
- table, create_ok=True,
+ table,
+ create_ok=True,
include_foreign_key_constraints=fkcs,
- _is_metadata_operation=True)
+ _is_metadata_operation=True,
+ )
else:
for fkc in fkcs:
self.traverse_single(fkc)
- metadata.dispatch.after_create(metadata, self.connection,
- tables=event_collection,
- checkfirst=self.checkfirst,
- _ddl_runner=self)
+ metadata.dispatch.after_create(
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
def visit_table(
- self, table, create_ok=False,
- include_foreign_key_constraints=None,
- _is_metadata_operation=False):
+ self,
+ table,
+ create_ok=False,
+ include_foreign_key_constraints=None,
+ _is_metadata_operation=False,
+ ):
if not create_ok and not self._can_create_table(table):
return
table.dispatch.before_create(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
for column in table.columns:
if column.default is not None:
@@ -788,10 +817,11 @@ class SchemaGenerator(DDLBase):
self.connection.execute(
CreateTable(
table,
- include_foreign_key_constraints=include_foreign_key_constraints
- ))
+ include_foreign_key_constraints=include_foreign_key_constraints,
+ )
+ )
- if hasattr(table, 'indexes'):
+ if hasattr(table, "indexes"):
for index in table.indexes:
self.traverse_single(index)
@@ -804,10 +834,12 @@ class SchemaGenerator(DDLBase):
self.connection.execute(SetColumnComment(column))
table.dispatch.after_create(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
@@ -824,9 +856,9 @@ class SchemaGenerator(DDLBase):
class SchemaDropper(DDLBase):
-
- def __init__(self, dialect, connection, checkfirst=False,
- tables=None, **kwargs):
+ def __init__(
+ self, dialect, connection, checkfirst=False, tables=None, **kwargs
+ ):
super(SchemaDropper, self).__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
@@ -842,15 +874,17 @@ class SchemaDropper(DDLBase):
try:
unsorted_tables = [t for t in tables if self._can_drop_table(t)]
- collection = list(reversed(
- sort_tables_and_constraints(
- unsorted_tables,
- filter_fn=lambda constraint: False
- if not self.dialect.supports_alter
- or constraint.name is None
- else None
+ collection = list(
+ reversed(
+ sort_tables_and_constraints(
+ unsorted_tables,
+ filter_fn=lambda constraint: False
+ if not self.dialect.supports_alter
+ or constraint.name is None
+ else None,
+ )
)
- ))
+ )
except exc.CircularDependencyError as err2:
if not self.dialect.supports_alter:
util.warn(
@@ -862,16 +896,15 @@ class SchemaDropper(DDLBase):
"ForeignKeyConstraint "
"objects involved in the cycle to mark these as known "
"cycles that will be ignored."
- % (
- ", ".join(sorted([t.fullname for t in err2.cycles]))
- )
+ % (", ".join(sorted([t.fullname for t in err2.cycles])))
)
collection = [(t, ()) for t in unsorted_tables]
else:
util.raise_from_cause(
exc.CircularDependencyError(
err2.args[0],
- err2.cycles, err2.edges,
+ err2.cycles,
+ err2.edges,
msg="Can't sort tables for DROP; an "
"unresolvable foreign key "
"dependency exists between tables: %s. Please ensure "
@@ -880,9 +913,10 @@ class SchemaDropper(DDLBase):
"names so that they can be dropped using "
"DROP CONSTRAINT."
% (
- ", ".join(sorted([t.fullname for t in err2.cycles]))
- )
-
+ ", ".join(
+ sorted([t.fullname for t in err2.cycles])
+ )
+ ),
)
)
@@ -892,18 +926,21 @@ class SchemaDropper(DDLBase):
if s.column is None and self._can_drop_sequence(s)
]
- event_collection = [
- t for (t, fks) in collection if t is not None
- ]
+ event_collection = [t for (t, fks) in collection if t is not None]
metadata.dispatch.before_drop(
- metadata, self.connection, tables=event_collection,
- checkfirst=self.checkfirst, _ddl_runner=self)
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
for table, fkcs in collection:
if table is not None:
self.traverse_single(
- table, drop_ok=True, _is_metadata_operation=True)
+ table, drop_ok=True, _is_metadata_operation=True
+ )
else:
for fkc in fkcs:
self.traverse_single(fkc)
@@ -912,8 +949,12 @@ class SchemaDropper(DDLBase):
self.traverse_single(seq, drop_ok=True)
metadata.dispatch.after_drop(
- metadata, self.connection, tables=event_collection,
- checkfirst=self.checkfirst, _ddl_runner=self)
+ metadata,
+ self.connection,
+ tables=event_collection,
+ checkfirst=self.checkfirst,
+ _ddl_runner=self,
+ )
def _can_drop_table(self, table):
self.dialect.validate_identifier(table.name)
@@ -921,19 +962,20 @@ class SchemaDropper(DDLBase):
if effective_schema:
self.dialect.validate_identifier(effective_schema)
return not self.checkfirst or self.dialect.has_table(
- self.connection, table.name, schema=effective_schema)
+ self.connection, table.name, schema=effective_schema
+ )
def _can_drop_sequence(self, sequence):
effective_schema = self.connection.schema_for_object(sequence)
- return self.dialect.supports_sequences and \
- ((not self.dialect.sequences_optional or
- not sequence.optional) and
- (not self.checkfirst or
- self.dialect.has_sequence(
- self.connection,
- sequence.name,
- schema=effective_schema))
- )
+ return self.dialect.supports_sequences and (
+ (not self.dialect.sequences_optional or not sequence.optional)
+ and (
+ not self.checkfirst
+ or self.dialect.has_sequence(
+ self.connection, sequence.name, schema=effective_schema
+ )
+ )
+ )
def visit_index(self, index):
self.connection.execute(DropIndex(index))
@@ -943,10 +985,12 @@ class SchemaDropper(DDLBase):
return
table.dispatch.before_drop(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
self.connection.execute(DropTable(table))
@@ -960,10 +1004,12 @@ class SchemaDropper(DDLBase):
self.traverse_single(column.default)
table.dispatch.after_drop(
- table, self.connection,
+ table,
+ self.connection,
checkfirst=self.checkfirst,
_ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation)
+ _is_metadata_operation=_is_metadata_operation,
+ )
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
@@ -1019,25 +1065,29 @@ def sort_tables(tables, skip_fn=None, extra_dependencies=None):
"""
if skip_fn is not None:
+
def _skip_fn(fkc):
for fk in fkc.elements:
if skip_fn(fk):
return True
else:
return None
+
else:
_skip_fn = None
return [
- t for (t, fkcs) in
- sort_tables_and_constraints(
- tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies)
+ t
+ for (t, fkcs) in sort_tables_and_constraints(
+ tables, filter_fn=_skip_fn, extra_dependencies=extra_dependencies
+ )
if t is not None
]
def sort_tables_and_constraints(
- tables, filter_fn=None, extra_dependencies=None):
+ tables, filter_fn=None, extra_dependencies=None
+):
"""sort a collection of :class:`.Table` / :class:`.ForeignKeyConstraint`
objects.
@@ -1109,8 +1159,9 @@ def sort_tables_and_constraints(
try:
candidate_sort = list(
topological.sort(
- fixed_dependencies.union(mutable_dependencies), tables,
- deterministic_order=True
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ deterministic_order=True,
)
)
except exc.CircularDependencyError as err:
@@ -1118,8 +1169,10 @@ def sort_tables_and_constraints(
if edge in mutable_dependencies:
table = edge[1]
can_remove = [
- fkc for fkc in table.foreign_key_constraints
- if filter_fn is None or filter_fn(fkc) is not False]
+ fkc
+ for fkc in table.foreign_key_constraints
+ if filter_fn is None or filter_fn(fkc) is not False
+ ]
remaining_fkcs.update(can_remove)
for fkc in can_remove:
dependent_on = fkc.referred_table
@@ -1127,8 +1180,9 @@ def sort_tables_and_constraints(
mutable_dependencies.discard((dependent_on, table))
candidate_sort = list(
topological.sort(
- fixed_dependencies.union(mutable_dependencies), tables,
- deterministic_order=True
+ fixed_dependencies.union(mutable_dependencies),
+ tables,
+ deterministic_order=True,
)
)
diff --git a/lib/sqlalchemy/sql/default_comparator.py b/lib/sqlalchemy/sql/default_comparator.py
index 8149f9731..fa0052198 100644
--- a/lib/sqlalchemy/sql/default_comparator.py
+++ b/lib/sqlalchemy/sql/default_comparator.py
@@ -11,19 +11,43 @@
from .. import exc, util
from . import type_api
from . import operators
-from .elements import BindParameter, True_, False_, BinaryExpression, \
- Null, _const_expr, _clause_element_as_expr, \
- ClauseList, ColumnElement, TextClause, UnaryExpression, \
- collate, _is_literal, _literal_as_text, ClauseElement, and_, or_, \
- Slice, Visitable, _literal_as_binds, CollectionAggregate, \
- Tuple
+from .elements import (
+ BindParameter,
+ True_,
+ False_,
+ BinaryExpression,
+ Null,
+ _const_expr,
+ _clause_element_as_expr,
+ ClauseList,
+ ColumnElement,
+ TextClause,
+ UnaryExpression,
+ collate,
+ _is_literal,
+ _literal_as_text,
+ ClauseElement,
+ and_,
+ or_,
+ Slice,
+ Visitable,
+ _literal_as_binds,
+ CollectionAggregate,
+ Tuple,
+)
from .selectable import SelectBase, Alias, Selectable, ScalarSelect
-def _boolean_compare(expr, op, obj, negate=None, reverse=False,
- _python_is_types=(util.NoneType, bool),
- result_type = None,
- **kwargs):
+def _boolean_compare(
+ expr,
+ op,
+ obj,
+ negate=None,
+ reverse=False,
+ _python_is_types=(util.NoneType, bool),
+ result_type=None,
+ **kwargs
+):
if result_type is None:
result_type = type_api.BOOLEANTYPE
@@ -33,57 +57,64 @@ def _boolean_compare(expr, op, obj, negate=None, reverse=False,
# allow x ==/!= True/False to be treated as a literal.
# this comes out to "== / != true/false" or "1/0" if those
# constants aren't supported and works on all platforms
- if op in (operators.eq, operators.ne) and \
- isinstance(obj, (bool, True_, False_)):
- return BinaryExpression(expr,
- _literal_as_text(obj),
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ if op in (operators.eq, operators.ne) and isinstance(
+ obj, (bool, True_, False_)
+ ):
+ return BinaryExpression(
+ expr,
+ _literal_as_text(obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
elif op in (operators.is_distinct_from, operators.isnot_distinct_from):
- return BinaryExpression(expr,
- _literal_as_text(obj),
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ return BinaryExpression(
+ expr,
+ _literal_as_text(obj),
+ op,
+ type_=result_type,
+ negate=negate,
+ modifiers=kwargs,
+ )
else:
# all other None/True/False uses IS, IS NOT
if op in (operators.eq, operators.is_):
- return BinaryExpression(expr, _const_expr(obj),
- operators.is_,
- negate=operators.isnot,
- type_=result_type
- )
+ return BinaryExpression(
+ expr,
+ _const_expr(obj),
+ operators.is_,
+ negate=operators.isnot,
+ type_=result_type,
+ )
elif op in (operators.ne, operators.isnot):
- return BinaryExpression(expr, _const_expr(obj),
- operators.isnot,
- negate=operators.is_,
- type_=result_type
- )
+ return BinaryExpression(
+ expr,
+ _const_expr(obj),
+ operators.isnot,
+ negate=operators.is_,
+ type_=result_type,
+ )
else:
raise exc.ArgumentError(
"Only '=', '!=', 'is_()', 'isnot()', "
"'is_distinct_from()', 'isnot_distinct_from()' "
- "operators can be used with None/True/False")
+ "operators can be used with None/True/False"
+ )
else:
obj = _check_literal(expr, op, obj)
if reverse:
- return BinaryExpression(obj,
- expr,
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ return BinaryExpression(
+ obj, expr, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
else:
- return BinaryExpression(expr,
- obj,
- op,
- type_=result_type,
- negate=negate, modifiers=kwargs)
+ return BinaryExpression(
+ expr, obj, op, type_=result_type, negate=negate, modifiers=kwargs
+ )
-def _custom_op_operate(expr, op, obj, reverse=False, result_type=None,
- **kw):
+def _custom_op_operate(expr, op, obj, reverse=False, result_type=None, **kw):
if result_type is None:
if op.return_type:
result_type = op.return_type
@@ -91,11 +122,11 @@ def _custom_op_operate(expr, op, obj, reverse=False, result_type=None,
result_type = type_api.BOOLEANTYPE
return _binary_operate(
- expr, op, obj, reverse=reverse, result_type=result_type, **kw)
+ expr, op, obj, reverse=reverse, result_type=result_type, **kw
+ )
-def _binary_operate(expr, op, obj, reverse=False, result_type=None,
- **kw):
+def _binary_operate(expr, op, obj, reverse=False, result_type=None, **kw):
obj = _check_literal(expr, op, obj)
if reverse:
@@ -105,10 +136,10 @@ def _binary_operate(expr, op, obj, reverse=False, result_type=None,
if result_type is None:
op, result_type = left.comparator._adapt_expression(
- op, right.comparator)
+ op, right.comparator
+ )
- return BinaryExpression(
- left, right, op, type_=result_type, modifiers=kw)
+ return BinaryExpression(left, right, op, type_=result_type, modifiers=kw)
def _conjunction_operate(expr, op, other, **kw):
@@ -128,8 +159,7 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
seq_or_selectable = _clause_element_as_expr(seq_or_selectable)
if isinstance(seq_or_selectable, ScalarSelect):
- return _boolean_compare(expr, op, seq_or_selectable,
- negate=negate_op)
+ return _boolean_compare(expr, op, seq_or_selectable, negate=negate_op)
elif isinstance(seq_or_selectable, SelectBase):
# TODO: if we ever want to support (x, y, z) IN (select x,
@@ -138,32 +168,33 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
# does not export itself as a FROM clause
return _boolean_compare(
- expr, op, seq_or_selectable.as_scalar(),
- negate=negate_op, **kw)
+ expr, op, seq_or_selectable.as_scalar(), negate=negate_op, **kw
+ )
elif isinstance(seq_or_selectable, (Selectable, TextClause)):
- return _boolean_compare(expr, op, seq_or_selectable,
- negate=negate_op, **kw)
+ return _boolean_compare(
+ expr, op, seq_or_selectable, negate=negate_op, **kw
+ )
elif isinstance(seq_or_selectable, ClauseElement):
- if isinstance(seq_or_selectable, BindParameter) and \
- seq_or_selectable.expanding:
+ if (
+ isinstance(seq_or_selectable, BindParameter)
+ and seq_or_selectable.expanding
+ ):
if isinstance(expr, Tuple):
- seq_or_selectable = (
- seq_or_selectable._with_expanding_in_types(
- [elem.type for elem in expr]
- )
+ seq_or_selectable = seq_or_selectable._with_expanding_in_types(
+ [elem.type for elem in expr]
)
return _boolean_compare(
- expr, op,
- seq_or_selectable,
- negate=negate_op)
+ expr, op, seq_or_selectable, negate=negate_op
+ )
else:
raise exc.InvalidRequestError(
- 'in_() accepts'
- ' either a list of expressions, '
+ "in_() accepts"
+ " either a list of expressions, "
'a selectable, or an "expanding" bound parameter: %r'
- % seq_or_selectable)
+ % seq_or_selectable
+ )
# Handle non selectable arguments as sequences
args = []
@@ -171,9 +202,10 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
if not _is_literal(o):
if not isinstance(o, operators.ColumnOperators):
raise exc.InvalidRequestError(
- 'in_() accepts'
- ' either a list of expressions, '
- 'a selectable, or an "expanding" bound parameter: %r' % o)
+ "in_() accepts"
+ " either a list of expressions, "
+ 'a selectable, or an "expanding" bound parameter: %r' % o
+ )
elif o is None:
o = Null()
else:
@@ -182,15 +214,14 @@ def _in_impl(expr, op, seq_or_selectable, negate_op, **kw):
if len(args) == 0:
op, negate_op = (
- operators.empty_in_op,
- operators.empty_notin_op) if op is operators.in_op \
- else (
- operators.empty_notin_op,
- operators.empty_in_op)
+ (operators.empty_in_op, operators.empty_notin_op)
+ if op is operators.in_op
+ else (operators.empty_notin_op, operators.empty_in_op)
+ )
- return _boolean_compare(expr, op,
- ClauseList(*args).self_group(against=op),
- negate=negate_op)
+ return _boolean_compare(
+ expr, op, ClauseList(*args).self_group(against=op), negate=negate_op
+ )
def _getitem_impl(expr, op, other, **kw):
@@ -202,13 +233,14 @@ def _getitem_impl(expr, op, other, **kw):
def _unsupported_impl(expr, op, *arg, **kw):
- raise NotImplementedError("Operator '%s' is not supported on "
- "this expression" % op.__name__)
+ raise NotImplementedError(
+ "Operator '%s' is not supported on " "this expression" % op.__name__
+ )
def _inv_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.__inv__`."""
- if hasattr(expr, 'negation_clause'):
+ if hasattr(expr, "negation_clause"):
return expr.negation_clause
else:
return expr._negate()
@@ -223,20 +255,22 @@ def _match_impl(expr, op, other, **kw):
"""See :meth:`.ColumnOperators.match`."""
return _boolean_compare(
- expr, operators.match_op,
- _check_literal(
- expr, operators.match_op, other),
+ expr,
+ operators.match_op,
+ _check_literal(expr, operators.match_op, other),
result_type=type_api.MATCHTYPE,
negate=operators.notmatch_op
- if op is operators.match_op else operators.match_op,
+ if op is operators.match_op
+ else operators.match_op,
**kw
)
def _distinct_impl(expr, op, **kw):
"""See :meth:`.ColumnOperators.distinct`."""
- return UnaryExpression(expr, operator=operators.distinct_op,
- type_=expr.type)
+ return UnaryExpression(
+ expr, operator=operators.distinct_op, type_=expr.type
+ )
def _between_impl(expr, op, cleft, cright, **kw):
@@ -247,17 +281,21 @@ def _between_impl(expr, op, cleft, cright, **kw):
_check_literal(expr, operators.and_, cleft),
_check_literal(expr, operators.and_, cright),
operator=operators.and_,
- group=False, group_contents=False),
+ group=False,
+ group_contents=False,
+ ),
op,
negate=operators.notbetween_op
if op is operators.between_op
else operators.between_op,
- modifiers=kw)
+ modifiers=kw,
+ )
def _collate_impl(expr, op, other, **kw):
return collate(expr, other)
+
# a mapping of operators with the method they use, along with
# their negated operator for comparison operators
operator_lookup = {
@@ -271,8 +309,8 @@ operator_lookup = {
"mod": (_binary_operate,),
"truediv": (_binary_operate,),
"custom_op": (_custom_op_operate,),
- "json_path_getitem_op": (_binary_operate, ),
- "json_getitem_op": (_binary_operate, ),
+ "json_path_getitem_op": (_binary_operate,),
+ "json_getitem_op": (_binary_operate,),
"concat_op": (_binary_operate,),
"any_op": (_scalar, CollectionAggregate._create_any),
"all_op": (_scalar, CollectionAggregate._create_all),
@@ -303,8 +341,8 @@ operator_lookup = {
"match_op": (_match_impl,),
"notmatch_op": (_match_impl,),
"distinct_op": (_distinct_impl,),
- "between_op": (_between_impl, ),
- "notbetween_op": (_between_impl, ),
+ "between_op": (_between_impl,),
+ "notbetween_op": (_between_impl,),
"neg": (_neg_impl,),
"getitem": (_getitem_impl,),
"lshift": (_unsupported_impl,),
@@ -315,12 +353,11 @@ operator_lookup = {
def _check_literal(expr, operator, other, bindparam_type=None):
if isinstance(other, (ColumnElement, TextClause)):
- if isinstance(other, BindParameter) and \
- other.type._isnull:
+ if isinstance(other, BindParameter) and other.type._isnull:
other = other._clone()
other.type = expr.type
return other
- elif hasattr(other, '__clause_element__'):
+ elif hasattr(other, "__clause_element__"):
other = other.__clause_element__()
elif isinstance(other, type_api.TypeEngine.Comparator):
other = other.expr
@@ -331,4 +368,3 @@ def _check_literal(expr, operator, other, bindparam_type=None):
return expr._bind_param(operator, other, type_=bindparam_type)
else:
return other
-
diff --git a/lib/sqlalchemy/sql/dml.py b/lib/sqlalchemy/sql/dml.py
index d6890de15..0cea5ccc4 100644
--- a/lib/sqlalchemy/sql/dml.py
+++ b/lib/sqlalchemy/sql/dml.py
@@ -9,26 +9,43 @@ Provide :class:`.Insert`, :class:`.Update` and :class:`.Delete`.
"""
-from .base import Executable, _generative, _from_objects, DialectKWArgs, \
- ColumnCollection
-from .elements import ClauseElement, _literal_as_text, Null, and_, _clone, \
- _column_as_key
-from .selectable import _interpret_as_from, _interpret_as_select, \
- HasPrefixes, HasCTE
+from .base import (
+ Executable,
+ _generative,
+ _from_objects,
+ DialectKWArgs,
+ ColumnCollection,
+)
+from .elements import (
+ ClauseElement,
+ _literal_as_text,
+ Null,
+ and_,
+ _clone,
+ _column_as_key,
+)
+from .selectable import (
+ _interpret_as_from,
+ _interpret_as_select,
+ HasPrefixes,
+ HasCTE,
+)
from .. import util
from .. import exc
class UpdateBase(
- HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement):
+ HasCTE, DialectKWArgs, HasPrefixes, Executable, ClauseElement
+):
"""Form the base for ``INSERT``, ``UPDATE``, and ``DELETE`` statements.
"""
- __visit_name__ = 'update_base'
+ __visit_name__ = "update_base"
- _execution_options = \
- Executable._execution_options.union({'autocommit': True})
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": True}
+ )
_hints = util.immutabledict()
_parameter_ordering = None
_prefixes = ()
@@ -37,30 +54,33 @@ class UpdateBase(
def _process_colparams(self, parameters):
def process_single(p):
if isinstance(p, (list, tuple)):
- return dict(
- (c.key, pval)
- for c, pval in zip(self.table.c, p)
- )
+ return dict((c.key, pval) for c, pval in zip(self.table.c, p))
else:
return p
if self._preserve_parameter_order and parameters is not None:
- if not isinstance(parameters, list) or \
- (parameters and not isinstance(parameters[0], tuple)):
+ if not isinstance(parameters, list) or (
+ parameters and not isinstance(parameters[0], tuple)
+ ):
raise ValueError(
"When preserve_parameter_order is True, "
- "values() only accepts a list of 2-tuples")
+ "values() only accepts a list of 2-tuples"
+ )
self._parameter_ordering = [key for key, value in parameters]
return dict(parameters), False
- if (isinstance(parameters, (list, tuple)) and parameters and
- isinstance(parameters[0], (list, tuple, dict))):
+ if (
+ isinstance(parameters, (list, tuple))
+ and parameters
+ and isinstance(parameters[0], (list, tuple, dict))
+ ):
if not self._supports_multi_parameters:
raise exc.InvalidRequestError(
"This construct does not support "
- "multiple parameter sets.")
+ "multiple parameter sets."
+ )
return [process_single(p) for p in parameters], True
else:
@@ -77,7 +97,8 @@ class UpdateBase(
raise NotImplementedError(
"params() is not supported for INSERT/UPDATE/DELETE statements."
" To set the values for an INSERT or UPDATE statement, use"
- " stmt.values(**parameters).")
+ " stmt.values(**parameters)."
+ )
def bind(self):
"""Return a 'bind' linked to this :class:`.UpdateBase`
@@ -88,6 +109,7 @@ class UpdateBase(
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
@_generative
@@ -181,15 +203,14 @@ class UpdateBase(
if selectable is None:
selectable = self.table
- self._hints = self._hints.union(
- {(selectable, dialect_name): text})
+ self._hints = self._hints.union({(selectable, dialect_name): text})
class ValuesBase(UpdateBase):
"""Supplies support for :meth:`.ValuesBase.values` to
INSERT and UPDATE constructs."""
- __visit_name__ = 'values_base'
+ __visit_name__ = "values_base"
_supports_multi_parameters = False
_has_multi_parameters = False
@@ -199,8 +220,9 @@ class ValuesBase(UpdateBase):
def __init__(self, table, values, prefixes):
self.table = _interpret_as_from(table)
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(values)
+ self.parameters, self._has_multi_parameters = self._process_colparams(
+ values
+ )
if prefixes:
self._setup_prefixes(prefixes)
@@ -332,23 +354,27 @@ class ValuesBase(UpdateBase):
"""
if self.select is not None:
raise exc.InvalidRequestError(
- "This construct already inserts from a SELECT")
+ "This construct already inserts from a SELECT"
+ )
if self._has_multi_parameters and kwargs:
raise exc.InvalidRequestError(
- "This construct already has multiple parameter sets.")
+ "This construct already has multiple parameter sets."
+ )
if args:
if len(args) > 1:
raise exc.ArgumentError(
"Only a single dictionary/tuple or list of "
- "dictionaries/tuples is accepted positionally.")
+ "dictionaries/tuples is accepted positionally."
+ )
v = args[0]
else:
v = {}
if self.parameters is None:
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(v)
+ self.parameters, self._has_multi_parameters = self._process_colparams(
+ v
+ )
else:
if self._has_multi_parameters:
self.parameters = list(self.parameters)
@@ -356,7 +382,8 @@ class ValuesBase(UpdateBase):
if not self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
- "formats in one statement")
+ "formats in one statement"
+ )
self.parameters.extend(p)
else:
@@ -365,14 +392,16 @@ class ValuesBase(UpdateBase):
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't mix single-values and multiple values "
- "formats in one statement")
+ "formats in one statement"
+ )
self.parameters.update(p)
if kwargs:
if self._has_multi_parameters:
raise exc.ArgumentError(
"Can't pass kwargs and multiple parameter sets "
- "simultaneously")
+ "simultaneously"
+ )
else:
self.parameters.update(kwargs)
@@ -456,19 +485,22 @@ class Insert(ValuesBase):
:ref:`coretutorial_insert_expressions`
"""
- __visit_name__ = 'insert'
+
+ __visit_name__ = "insert"
_supports_multi_parameters = True
- def __init__(self,
- table,
- values=None,
- inline=False,
- bind=None,
- prefixes=None,
- returning=None,
- return_defaults=False,
- **dialect_kw):
+ def __init__(
+ self,
+ table,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ **dialect_kw
+ ):
"""Construct an :class:`.Insert` object.
Similar functionality is available via the
@@ -526,7 +558,7 @@ class Insert(ValuesBase):
def get_children(self, **kwargs):
if self.select is not None:
- return self.select,
+ return (self.select,)
else:
return ()
@@ -578,11 +610,12 @@ class Insert(ValuesBase):
"""
if self.parameters:
raise exc.InvalidRequestError(
- "This construct already inserts value expressions")
+ "This construct already inserts value expressions"
+ )
- self.parameters, self._has_multi_parameters = \
- self._process_colparams(
- {_column_as_key(n): Null() for n in names})
+ self.parameters, self._has_multi_parameters = self._process_colparams(
+ {_column_as_key(n): Null() for n in names}
+ )
self.select_names = names
self.inline = True
@@ -603,19 +636,22 @@ class Update(ValuesBase):
function.
"""
- __visit_name__ = 'update'
-
- def __init__(self,
- table,
- whereclause=None,
- values=None,
- inline=False,
- bind=None,
- prefixes=None,
- returning=None,
- return_defaults=False,
- preserve_parameter_order=False,
- **dialect_kw):
+
+ __visit_name__ = "update"
+
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ values=None,
+ inline=False,
+ bind=None,
+ prefixes=None,
+ returning=None,
+ return_defaults=False,
+ preserve_parameter_order=False,
+ **dialect_kw
+ ):
r"""Construct an :class:`.Update` object.
E.g.::
@@ -745,7 +781,7 @@ class Update(ValuesBase):
def get_children(self, **kwargs):
if self._whereclause is not None:
- return self._whereclause,
+ return (self._whereclause,)
else:
return ()
@@ -761,8 +797,9 @@ class Update(ValuesBase):
"""
if self._whereclause is not None:
- self._whereclause = and_(self._whereclause,
- _literal_as_text(whereclause))
+ self._whereclause = and_(
+ self._whereclause, _literal_as_text(whereclause)
+ )
else:
self._whereclause = _literal_as_text(whereclause)
@@ -788,15 +825,17 @@ class Delete(UpdateBase):
"""
- __visit_name__ = 'delete'
-
- def __init__(self,
- table,
- whereclause=None,
- bind=None,
- returning=None,
- prefixes=None,
- **dialect_kw):
+ __visit_name__ = "delete"
+
+ def __init__(
+ self,
+ table,
+ whereclause=None,
+ bind=None,
+ returning=None,
+ prefixes=None,
+ **dialect_kw
+ ):
"""Construct :class:`.Delete` object.
Similar functionality is available via the
@@ -847,7 +886,7 @@ class Delete(UpdateBase):
def get_children(self, **kwargs):
if self._whereclause is not None:
- return self._whereclause,
+ return (self._whereclause,)
else:
return ()
@@ -856,8 +895,9 @@ class Delete(UpdateBase):
"""Add the given WHERE clause to a newly returned delete construct."""
if self._whereclause is not None:
- self._whereclause = and_(self._whereclause,
- _literal_as_text(whereclause))
+ self._whereclause = and_(
+ self._whereclause, _literal_as_text(whereclause)
+ )
else:
self._whereclause = _literal_as_text(whereclause)
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index de3b7992a..e857f2da8 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -51,9 +51,8 @@ def collate(expression, collation):
expr = _literal_as_binds(expression)
return BinaryExpression(
- expr,
- CollationClause(collation),
- operators.collate, type_=expr.type)
+ expr, CollationClause(collation), operators.collate, type_=expr.type
+ )
def between(expr, lower_bound, upper_bound, symmetric=False):
@@ -130,8 +129,6 @@ def literal(value, type_=None):
return BindParameter(None, value, type_=type_, unique=True)
-
-
def outparam(key, type_=None):
"""Create an 'OUT' parameter for usage in functions (stored procedures),
for databases which support them.
@@ -142,8 +139,7 @@ def outparam(key, type_=None):
attribute, which returns a dictionary containing the values.
"""
- return BindParameter(
- key, None, type_=type_, unique=False, isoutparam=True)
+ return BindParameter(key, None, type_=type_, unique=False, isoutparam=True)
def not_(clause):
@@ -163,7 +159,8 @@ class ClauseElement(Visitable):
expression.
"""
- __visit_name__ = 'clause'
+
+ __visit_name__ = "clause"
_annotations = {}
supports_execution = False
@@ -230,7 +227,7 @@ class ClauseElement(Visitable):
def __getstate__(self):
d = self.__dict__.copy()
- d.pop('_is_clone_of', None)
+ d.pop("_is_clone_of", None)
return d
def _annotate(self, values):
@@ -300,7 +297,8 @@ class ClauseElement(Visitable):
kwargs.update(optionaldict[0])
elif len(optionaldict) > 1:
raise exc.ArgumentError(
- "params() takes zero or one positional dictionary argument")
+ "params() takes zero or one positional dictionary argument"
+ )
def visit_bindparam(bind):
if bind.key in kwargs:
@@ -308,7 +306,8 @@ class ClauseElement(Visitable):
bind.required = False
if unique:
bind._convert_to_unique()
- return cloned_traverse(self, {}, {'bindparam': visit_bindparam})
+
+ return cloned_traverse(self, {}, {"bindparam": visit_bindparam})
def compare(self, other, **kw):
r"""Compare this ClauseElement to the given ClauseElement.
@@ -451,7 +450,7 @@ class ClauseElement(Visitable):
if util.py3k:
return str(self.compile())
else:
- return unicode(self.compile()).encode('ascii', 'backslashreplace')
+ return unicode(self.compile()).encode("ascii", "backslashreplace")
def __and__(self, other):
"""'and' at the ClauseElement level.
@@ -472,7 +471,7 @@ class ClauseElement(Visitable):
return or_(self, other)
def __invert__(self):
- if hasattr(self, 'negation_clause'):
+ if hasattr(self, "negation_clause"):
return self.negation_clause
else:
return self._negate()
@@ -481,7 +480,8 @@ class ClauseElement(Visitable):
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
- negate=None)
+ negate=None,
+ )
def __bool__(self):
raise TypeError("Boolean value of this clause is not defined")
@@ -493,8 +493,12 @@ class ClauseElement(Visitable):
if friendly is None:
return object.__repr__(self)
else:
- return '<%s.%s at 0x%x; %s>' % (
- self.__module__, self.__class__.__name__, id(self), friendly)
+ return "<%s.%s at 0x%x; %s>" % (
+ self.__module__,
+ self.__class__.__name__,
+ id(self),
+ friendly,
+ )
class ColumnElement(operators.ColumnOperators, ClauseElement):
@@ -571,7 +575,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
"""
- __visit_name__ = 'column_element'
+ __visit_name__ = "column_element"
primary_key = False
foreign_keys = []
@@ -646,11 +650,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
_alt_names = ()
def self_group(self, against=None):
- if (against in (operators.and_, operators.or_, operators._asbool) and
- self.type._type_affinity
- is type_api.BOOLEANTYPE._type_affinity):
+ if (
+ against in (operators.and_, operators.or_, operators._asbool)
+ and self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity
+ ):
return AsBoolean(self, operators.istrue, operators.isfalse)
- elif (against in (operators.any_op, operators.all_op)):
+ elif against in (operators.any_op, operators.all_op):
return Grouping(self)
else:
return self
@@ -675,7 +680,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
except AttributeError:
raise TypeError(
"Object %r associated with '.type' attribute "
- "is not a TypeEngine class or object" % self.type)
+ "is not a TypeEngine class or object" % self.type
+ )
else:
return comparator_factory(self)
@@ -684,10 +690,8 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
return getattr(self.comparator, key)
except AttributeError:
raise AttributeError(
- 'Neither %r object nor %r object has an attribute %r' % (
- type(self).__name__,
- type(self.comparator).__name__,
- key)
+ "Neither %r object nor %r object has an attribute %r"
+ % (type(self).__name__, type(self.comparator).__name__, key)
)
def operate(self, op, *other, **kwargs):
@@ -697,10 +701,14 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
return op(other, self.comparator, **kwargs)
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(None, obj,
- _compared_to_operator=operator,
- type_=type_,
- _compared_to_type=self.type, unique=True)
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ type_=type_,
+ _compared_to_type=self.type,
+ unique=True,
+ )
@property
def expression(self):
@@ -713,17 +721,18 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
@property
def _select_iterable(self):
- return (self, )
+ return (self,)
@util.memoized_property
def base_columns(self):
- return util.column_set(c for c in self.proxy_set
- if not hasattr(c, '_proxies'))
+ return util.column_set(
+ c for c in self.proxy_set if not hasattr(c, "_proxies")
+ )
@util.memoized_property
def proxy_set(self):
s = util.column_set([self])
- if hasattr(self, '_proxies'):
+ if hasattr(self, "_proxies"):
for c in self._proxies:
s.update(c.proxy_set)
return s
@@ -738,11 +747,15 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
"""Return True if the given column element compares to this one
when targeting within a result row."""
- return hasattr(other, 'name') and hasattr(self, 'name') and \
- other.name == self.name
+ return (
+ hasattr(other, "name")
+ and hasattr(self, "name")
+ and other.name == self.name
+ )
def _make_proxy(
- self, selectable, name=None, name_is_truncatable=False, **kw):
+ self, selectable, name=None, name_is_truncatable=False, **kw
+ ):
"""Create a new :class:`.ColumnElement` representing this
:class:`.ColumnElement` as it appears in the select list of a
descending selectable.
@@ -762,13 +775,12 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
key = name
co = ColumnClause(
_as_truncated(name) if name_is_truncatable else name,
- type_=getattr(self, 'type', None),
- _selectable=selectable
+ type_=getattr(self, "type", None),
+ _selectable=selectable,
)
co._proxies = [self]
if selectable._is_clone_of is not None:
- co._is_clone_of = \
- selectable._is_clone_of.columns.get(key)
+ co._is_clone_of = selectable._is_clone_of.columns.get(key)
selectable._columns[key] = co
return co
@@ -788,7 +800,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
this one via foreign key or other criterion.
"""
- to_compare = (other, )
+ to_compare = (other,)
if equivalents and other in equivalents:
to_compare = equivalents[other].union(to_compare)
@@ -838,7 +850,7 @@ class ColumnElement(operators.ColumnOperators, ClauseElement):
self = self._is_clone_of
return _anonymous_label(
- '%%(%d %s)s' % (id(self), getattr(self, 'name', 'anon'))
+ "%%(%d %s)s" % (id(self), getattr(self, "name", "anon"))
)
@@ -862,18 +874,25 @@ class BindParameter(ColumnElement):
"""
- __visit_name__ = 'bindparam'
+ __visit_name__ = "bindparam"
_is_crud = False
_expanding_in_types = ()
- def __init__(self, key, value=NO_ARG, type_=None,
- unique=False, required=NO_ARG,
- quote=None, callable_=None,
- expanding=False,
- isoutparam=False,
- _compared_to_operator=None,
- _compared_to_type=None):
+ def __init__(
+ self,
+ key,
+ value=NO_ARG,
+ type_=None,
+ unique=False,
+ required=NO_ARG,
+ quote=None,
+ callable_=None,
+ expanding=False,
+ isoutparam=False,
+ _compared_to_operator=None,
+ _compared_to_type=None,
+ ):
r"""Produce a "bound expression".
The return value is an instance of :class:`.BindParameter`; this
@@ -1093,7 +1112,7 @@ class BindParameter(ColumnElement):
type_ = key.type
key = key.key
if required is NO_ARG:
- required = (value is NO_ARG and callable_ is None)
+ required = value is NO_ARG and callable_ is None
if value is NO_ARG:
value = None
@@ -1101,11 +1120,11 @@ class BindParameter(ColumnElement):
key = quoted_name(key, quote)
if unique:
- self.key = _anonymous_label('%%(%d %s)s' % (id(self), key
- or 'param'))
+ self.key = _anonymous_label(
+ "%%(%d %s)s" % (id(self), key or "param")
+ )
else:
- self.key = key or _anonymous_label('%%(%d param)s'
- % id(self))
+ self.key = key or _anonymous_label("%%(%d param)s" % id(self))
# identifying key that won't change across
# clones, used to identify the bind's logical
@@ -1114,7 +1133,7 @@ class BindParameter(ColumnElement):
# key that was passed in the first place, used to
# generate new keys
- self._orig_key = key or 'param'
+ self._orig_key = key or "param"
self.unique = unique
self.value = value
@@ -1125,9 +1144,9 @@ class BindParameter(ColumnElement):
if type_ is None:
if _compared_to_type is not None:
- self.type = \
- _compared_to_type.coerce_compared_value(
- _compared_to_operator, value)
+ self.type = _compared_to_type.coerce_compared_value(
+ _compared_to_operator, value
+ )
else:
self.type = type_api._resolve_value_to_type(value)
elif isinstance(type_, type):
@@ -1174,24 +1193,28 @@ class BindParameter(ColumnElement):
def _clone(self):
c = ClauseElement._clone(self)
if self.unique:
- c.key = _anonymous_label('%%(%d %s)s' % (id(c), c._orig_key
- or 'param'))
+ c.key = _anonymous_label(
+ "%%(%d %s)s" % (id(c), c._orig_key or "param")
+ )
return c
def _convert_to_unique(self):
if not self.unique:
self.unique = True
self.key = _anonymous_label(
- '%%(%d %s)s' % (id(self), self._orig_key or 'param'))
+ "%%(%d %s)s" % (id(self), self._orig_key or "param")
+ )
def compare(self, other, **kw):
"""Compare this :class:`BindParameter` to the given
clause."""
- return isinstance(other, BindParameter) \
- and self.type._compare_type_affinity(other.type) \
- and self.value == other.value \
+ return (
+ isinstance(other, BindParameter)
+ and self.type._compare_type_affinity(other.type)
+ and self.value == other.value
and self.callable == other.callable
+ )
def __getstate__(self):
"""execute a deferred value for serialization purposes."""
@@ -1200,13 +1223,16 @@ class BindParameter(ColumnElement):
v = self.value
if self.callable:
v = self.callable()
- d['callable'] = None
- d['value'] = v
+ d["callable"] = None
+ d["value"] = v
return d
def __repr__(self):
- return 'BindParameter(%r, %r, type_=%r)' % (self.key,
- self.value, self.type)
+ return "BindParameter(%r, %r, type_=%r)" % (
+ self.key,
+ self.value,
+ self.type,
+ )
class TypeClause(ClauseElement):
@@ -1216,7 +1242,7 @@ class TypeClause(ClauseElement):
"""
- __visit_name__ = 'typeclause'
+ __visit_name__ = "typeclause"
def __init__(self, type):
self.type = type
@@ -1242,12 +1268,12 @@ class TextClause(Executable, ClauseElement):
"""
- __visit_name__ = 'textclause'
+ __visit_name__ = "textclause"
- _bind_params_regex = re.compile(r'(?<![:\w\x5c]):(\w+)(?!:)', re.UNICODE)
- _execution_options = \
- Executable._execution_options.union(
- {'autocommit': PARSE_AUTOCOMMIT})
+ _bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE)
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": PARSE_AUTOCOMMIT}
+ )
_is_implicitly_boolean = False
@property
@@ -1268,24 +1294,22 @@ class TextClause(Executable, ClauseElement):
_allow_label_resolve = False
- def __init__(
- self,
- text,
- bind=None):
+ def __init__(self, text, bind=None):
self._bind = bind
self._bindparams = {}
def repl(m):
self._bindparams[m.group(1)] = BindParameter(m.group(1))
- return ':%s' % m.group(1)
+ return ":%s" % m.group(1)
# scan the string and search for bind parameter names, add them
# to the list of bindparams
self.text = self._bind_params_regex.sub(repl, text)
@classmethod
- def _create_text(self, text, bind=None, bindparams=None,
- typemap=None, autocommit=None):
+ def _create_text(
+ self, text, bind=None, bindparams=None, typemap=None, autocommit=None
+ ):
r"""Construct a new :class:`.TextClause` clause, representing
a textual SQL string directly.
@@ -1428,8 +1452,10 @@ class TextClause(Executable, ClauseElement):
if typemap:
stmt = stmt.columns(**typemap)
if autocommit is not None:
- util.warn_deprecated('autocommit on text() is deprecated. '
- 'Use .execution_options(autocommit=True)')
+ util.warn_deprecated(
+ "autocommit on text() is deprecated. "
+ "Use .execution_options(autocommit=True)"
+ )
stmt = stmt.execution_options(autocommit=autocommit)
return stmt
@@ -1513,7 +1539,8 @@ class TextClause(Executable, ClauseElement):
except KeyError:
raise exc.ArgumentError(
"This text() construct doesn't define a "
- "bound parameter named %r" % bind.key)
+ "bound parameter named %r" % bind.key
+ )
else:
new_params[existing.key] = bind
@@ -1523,11 +1550,12 @@ class TextClause(Executable, ClauseElement):
except KeyError:
raise exc.ArgumentError(
"This text() construct doesn't define a "
- "bound parameter named %r" % key)
+ "bound parameter named %r" % key
+ )
else:
new_params[key] = existing._with_value(value)
- @util.dependencies('sqlalchemy.sql.selectable')
+ @util.dependencies("sqlalchemy.sql.selectable")
def columns(self, selectable, *cols, **types):
"""Turn this :class:`.TextClause` object into a :class:`.TextAsFrom`
object that can be embedded into another statement.
@@ -1629,12 +1657,14 @@ class TextClause(Executable, ClauseElement):
for col in cols
]
keyed_input_cols = [
- ColumnClause(key, type_) for key, type_ in types.items()]
+ ColumnClause(key, type_) for key, type_ in types.items()
+ ]
return selectable.TextAsFrom(
self,
positional_input_cols + keyed_input_cols,
- positional=bool(positional_input_cols) and not keyed_input_cols)
+ positional=bool(positional_input_cols) and not keyed_input_cols,
+ )
@property
def type(self):
@@ -1651,8 +1681,9 @@ class TextClause(Executable, ClauseElement):
return self
def _copy_internals(self, clone=_clone, **kw):
- self._bindparams = dict((b.key, clone(b, **kw))
- for b in self._bindparams.values())
+ self._bindparams = dict(
+ (b.key, clone(b, **kw)) for b in self._bindparams.values()
+ )
def get_children(self, **kwargs):
return list(self._bindparams.values())
@@ -1669,7 +1700,7 @@ class Null(ColumnElement):
"""
- __visit_name__ = 'null'
+ __visit_name__ = "null"
@util.memoized_property
def type(self):
@@ -1693,7 +1724,7 @@ class False_(ColumnElement):
"""
- __visit_name__ = 'false'
+ __visit_name__ = "false"
@util.memoized_property
def type(self):
@@ -1752,7 +1783,7 @@ class True_(ColumnElement):
"""
- __visit_name__ = 'true'
+ __visit_name__ = "true"
@util.memoized_property
def type(self):
@@ -1816,23 +1847,23 @@ class ClauseList(ClauseElement):
By default, is comma-separated, such as a column listing.
"""
- __visit_name__ = 'clauselist'
+
+ __visit_name__ = "clauselist"
def __init__(self, *clauses, **kwargs):
- self.operator = kwargs.pop('operator', operators.comma_op)
- self.group = kwargs.pop('group', True)
- self.group_contents = kwargs.pop('group_contents', True)
+ self.operator = kwargs.pop("operator", operators.comma_op)
+ self.group = kwargs.pop("group", True)
+ self.group_contents = kwargs.pop("group_contents", True)
text_converter = kwargs.pop(
- '_literal_as_text',
- _expression_literal_as_text)
+ "_literal_as_text", _expression_literal_as_text
+ )
if self.group_contents:
self.clauses = [
text_converter(clause).self_group(against=self.operator)
- for clause in clauses]
+ for clause in clauses
+ ]
else:
- self.clauses = [
- text_converter(clause)
- for clause in clauses]
+ self.clauses = [text_converter(clause) for clause in clauses]
self._is_implicitly_boolean = operators.is_boolean(self.operator)
def __iter__(self):
@@ -1847,8 +1878,9 @@ class ClauseList(ClauseElement):
def append(self, clause):
if self.group_contents:
- self.clauses.append(_literal_as_text(clause).
- self_group(against=self.operator))
+ self.clauses.append(
+ _literal_as_text(clause).self_group(against=self.operator)
+ )
else:
self.clauses.append(_literal_as_text(clause))
@@ -1875,14 +1907,18 @@ class ClauseList(ClauseElement):
"""
if not isinstance(other, ClauseList) and len(self.clauses) == 1:
return self.clauses[0].compare(other, **kw)
- elif isinstance(other, ClauseList) and \
- len(self.clauses) == len(other.clauses) and \
- self.operator is other.operator:
+ elif (
+ isinstance(other, ClauseList)
+ and len(self.clauses) == len(other.clauses)
+ and self.operator is other.operator
+ ):
if self.operator in (operators.and_, operators.or_):
completed = set()
for clause in self.clauses:
- for other_clause in set(other.clauses).difference(completed):
+ for other_clause in set(other.clauses).difference(
+ completed
+ ):
if clause.compare(other_clause, **kw):
completed.add(other_clause)
break
@@ -1898,11 +1934,12 @@ class ClauseList(ClauseElement):
class BooleanClauseList(ClauseList, ColumnElement):
- __visit_name__ = 'clauselist'
+ __visit_name__ = "clauselist"
def __init__(self, *arg, **kw):
raise NotImplementedError(
- "BooleanClauseList has a private constructor")
+ "BooleanClauseList has a private constructor"
+ )
@classmethod
def _construct(cls, operator, continue_on, skip_on, *clauses, **kw):
@@ -1910,8 +1947,7 @@ class BooleanClauseList(ClauseList, ColumnElement):
clauses = [
_expression_literal_as_text(clause)
- for clause in
- util.coerce_generator_arg(clauses)
+ for clause in util.coerce_generator_arg(clauses)
]
for clause in clauses:
@@ -1927,8 +1963,9 @@ class BooleanClauseList(ClauseList, ColumnElement):
elif not convert_clauses and clauses:
return clauses[0].self_group(against=operators._asbool)
- convert_clauses = [c.self_group(against=operator)
- for c in convert_clauses]
+ convert_clauses = [
+ c.self_group(against=operator) for c in convert_clauses
+ ]
self = cls.__new__(cls)
self.clauses = convert_clauses
@@ -2014,7 +2051,7 @@ class BooleanClauseList(ClauseList, ColumnElement):
@property
def _select_iterable(self):
- return (self, )
+ return (self,)
def self_group(self, against=None):
if not self.clauses:
@@ -2056,22 +2093,31 @@ class Tuple(ClauseList, ColumnElement):
clauses = [_literal_as_binds(c) for c in clauses]
self._type_tuple = [arg.type for arg in clauses]
- self.type = kw.pop('type_', self._type_tuple[0]
- if self._type_tuple else type_api.NULLTYPE)
+ self.type = kw.pop(
+ "type_",
+ self._type_tuple[0] if self._type_tuple else type_api.NULLTYPE,
+ )
super(Tuple, self).__init__(*clauses, **kw)
@property
def _select_iterable(self):
- return (self, )
+ return (self,)
def _bind_param(self, operator, obj, type_=None):
- return Tuple(*[
- BindParameter(None, o, _compared_to_operator=operator,
- _compared_to_type=compared_to_type, unique=True,
- type_=type_)
- for o, compared_to_type in zip(obj, self._type_tuple)
- ]).self_group()
+ return Tuple(
+ *[
+ BindParameter(
+ None,
+ o,
+ _compared_to_operator=operator,
+ _compared_to_type=compared_to_type,
+ unique=True,
+ type_=type_,
+ )
+ for o, compared_to_type in zip(obj, self._type_tuple)
+ ]
+ ).self_group()
class Case(ColumnElement):
@@ -2101,7 +2147,7 @@ class Case(ColumnElement):
"""
- __visit_name__ = 'case'
+ __visit_name__ = "case"
def __init__(self, whens, value=None, else_=None):
r"""Produce a ``CASE`` expression.
@@ -2231,13 +2277,13 @@ class Case(ColumnElement):
if value is not None:
whenlist = [
- (_literal_as_binds(c).self_group(),
- _literal_as_binds(r)) for (c, r) in whens
+ (_literal_as_binds(c).self_group(), _literal_as_binds(r))
+ for (c, r) in whens
]
else:
whenlist = [
- (_no_literals(c).self_group(),
- _literal_as_binds(r)) for (c, r) in whens
+ (_no_literals(c).self_group(), _literal_as_binds(r))
+ for (c, r) in whens
]
if whenlist:
@@ -2260,8 +2306,7 @@ class Case(ColumnElement):
def _copy_internals(self, clone=_clone, **kw):
if self.value is not None:
self.value = clone(self.value, **kw)
- self.whens = [(clone(x, **kw), clone(y, **kw))
- for x, y in self.whens]
+ self.whens = [(clone(x, **kw), clone(y, **kw)) for x, y in self.whens]
if self.else_ is not None:
self.else_ = clone(self.else_, **kw)
@@ -2276,8 +2321,9 @@ class Case(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(*[x._from_objects for x in
- self.get_children()]))
+ return list(
+ itertools.chain(*[x._from_objects for x in self.get_children()])
+ )
def literal_column(text, type_=None):
@@ -2333,7 +2379,7 @@ class Cast(ColumnElement):
"""
- __visit_name__ = 'cast'
+ __visit_name__ = "cast"
def __init__(self, expression, type_):
"""Produce a ``CAST`` expression.
@@ -2416,7 +2462,7 @@ class TypeCoerce(ColumnElement):
"""
- __visit_name__ = 'type_coerce'
+ __visit_name__ = "type_coerce"
def __init__(self, expression, type_):
"""Associate a SQL expression with a particular type, without rendering
@@ -2484,10 +2530,10 @@ class TypeCoerce(ColumnElement):
def _copy_internals(self, clone=_clone, **kw):
self.clause = clone(self.clause, **kw)
- self.__dict__.pop('typed_expression', None)
+ self.__dict__.pop("typed_expression", None)
def get_children(self, **kwargs):
- return self.clause,
+ return (self.clause,)
@property
def _from_objects(self):
@@ -2506,7 +2552,7 @@ class TypeCoerce(ColumnElement):
class Extract(ColumnElement):
"""Represent a SQL EXTRACT clause, ``extract(field FROM expr)``."""
- __visit_name__ = 'extract'
+ __visit_name__ = "extract"
def __init__(self, field, expr, **kwargs):
"""Return a :class:`.Extract` construct.
@@ -2524,7 +2570,7 @@ class Extract(ColumnElement):
self.expr = clone(self.expr, **kw)
def get_children(self, **kwargs):
- return self.expr,
+ return (self.expr,)
@property
def _from_objects(self):
@@ -2543,7 +2589,8 @@ class _label_reference(ColumnElement):
within an OVER clause.
"""
- __visit_name__ = 'label_reference'
+
+ __visit_name__ = "label_reference"
def __init__(self, element):
self.element = element
@@ -2557,7 +2604,7 @@ class _label_reference(ColumnElement):
class _textual_label_reference(ColumnElement):
- __visit_name__ = 'textual_label_reference'
+ __visit_name__ = "textual_label_reference"
def __init__(self, element):
self.element = element
@@ -2580,14 +2627,23 @@ class UnaryExpression(ColumnElement):
:func:`.nullsfirst` and :func:`.nullslast`.
"""
- __visit_name__ = 'unary'
- def __init__(self, element, operator=None, modifier=None,
- type_=None, negate=None, wraps_column_expression=False):
+ __visit_name__ = "unary"
+
+ def __init__(
+ self,
+ element,
+ operator=None,
+ modifier=None,
+ type_=None,
+ negate=None,
+ wraps_column_expression=False,
+ ):
self.operator = operator
self.modifier = modifier
self.element = element.self_group(
- against=self.operator or self.modifier)
+ against=self.operator or self.modifier
+ )
self.type = type_api.to_instance(type_)
self.negate = negate
self.wraps_column_expression = wraps_column_expression
@@ -2633,7 +2689,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.nullsfirst_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_nullslast(cls, column):
@@ -2675,7 +2732,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.nullslast_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_desc(cls, column):
@@ -2715,7 +2773,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.desc_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_asc(cls, column):
@@ -2754,7 +2813,8 @@ class UnaryExpression(ColumnElement):
return UnaryExpression(
_literal_as_label_reference(column),
modifier=operators.asc_op,
- wraps_column_expression=False)
+ wraps_column_expression=False,
+ )
@classmethod
def _create_distinct(cls, expr):
@@ -2794,8 +2854,11 @@ class UnaryExpression(ColumnElement):
"""
expr = _literal_as_binds(expr)
return UnaryExpression(
- expr, operator=operators.distinct_op,
- type_=expr.type, wraps_column_expression=False)
+ expr,
+ operator=operators.distinct_op,
+ type_=expr.type,
+ wraps_column_expression=False,
+ )
@property
def _order_by_label_element(self):
@@ -2812,17 +2875,17 @@ class UnaryExpression(ColumnElement):
self.element = clone(self.element, **kw)
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
def compare(self, other, **kw):
"""Compare this :class:`UnaryExpression` against the given
:class:`.ClauseElement`."""
return (
- isinstance(other, UnaryExpression) and
- self.operator == other.operator and
- self.modifier == other.modifier and
- self.element.compare(other.element, **kw)
+ isinstance(other, UnaryExpression)
+ and self.operator == other.operator
+ and self.modifier == other.modifier
+ and self.element.compare(other.element, **kw)
)
def _negate(self):
@@ -2833,14 +2896,16 @@ class UnaryExpression(ColumnElement):
negate=self.operator,
modifier=self.modifier,
type_=self.type,
- wraps_column_expression=self.wraps_column_expression)
+ wraps_column_expression=self.wraps_column_expression,
+ )
elif self.type._type_affinity is type_api.BOOLEANTYPE._type_affinity:
return UnaryExpression(
self.self_group(against=operators.inv),
operator=operators.inv,
type_=type_api.BOOLEANTYPE,
wraps_column_expression=self.wraps_column_expression,
- negate=None)
+ negate=None,
+ )
else:
return ClauseElement._negate(self)
@@ -2860,6 +2925,7 @@ class CollectionAggregate(UnaryExpression):
MySQL, they only work for subqueries.
"""
+
@classmethod
def _create_any(cls, expr):
"""Produce an ANY expression.
@@ -2883,12 +2949,15 @@ class CollectionAggregate(UnaryExpression):
expr = _literal_as_binds(expr)
- if expr.is_selectable and hasattr(expr, 'as_scalar'):
+ if expr.is_selectable and hasattr(expr, "as_scalar"):
expr = expr.as_scalar()
expr = expr.self_group()
return CollectionAggregate(
- expr, operator=operators.any_op,
- type_=type_api.NULLTYPE, wraps_column_expression=False)
+ expr,
+ operator=operators.any_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
@classmethod
def _create_all(cls, expr):
@@ -2912,12 +2981,15 @@ class CollectionAggregate(UnaryExpression):
"""
expr = _literal_as_binds(expr)
- if expr.is_selectable and hasattr(expr, 'as_scalar'):
+ if expr.is_selectable and hasattr(expr, "as_scalar"):
expr = expr.as_scalar()
expr = expr.self_group()
return CollectionAggregate(
- expr, operator=operators.all_op,
- type_=type_api.NULLTYPE, wraps_column_expression=False)
+ expr,
+ operator=operators.all_op,
+ type_=type_api.NULLTYPE,
+ wraps_column_expression=False,
+ )
# operate and reverse_operate are hardwired to
# dispatch onto the type comparator directly, so that we can
@@ -2925,19 +2997,20 @@ class CollectionAggregate(UnaryExpression):
def operate(self, op, *other, **kwargs):
if not operators.is_comparison(op):
raise exc.ArgumentError(
- "Only comparison operators may be used with ANY/ALL")
- kwargs['reverse'] = True
+ "Only comparison operators may be used with ANY/ALL"
+ )
+ kwargs["reverse"] = True
return self.comparator.operate(operators.mirror(op), *other, **kwargs)
def reverse_operate(self, op, other, **kwargs):
# comparison operators should never call reverse_operate
assert not operators.is_comparison(op)
raise exc.ArgumentError(
- "Only comparison operators may be used with ANY/ALL")
+ "Only comparison operators may be used with ANY/ALL"
+ )
class AsBoolean(UnaryExpression):
-
def __init__(self, element, operator, negate):
self.element = element
self.type = type_api.BOOLEANTYPE
@@ -2971,7 +3044,7 @@ class BinaryExpression(ColumnElement):
"""
- __visit_name__ = 'binary'
+ __visit_name__ = "binary"
_is_implicitly_boolean = True
"""Indicates that any database will know this is a boolean expression
@@ -2979,8 +3052,9 @@ class BinaryExpression(ColumnElement):
"""
- def __init__(self, left, right, operator, type_=None,
- negate=None, modifiers=None):
+ def __init__(
+ self, left, right, operator, type_=None, negate=None, modifiers=None
+ ):
# allow compatibility with libraries that
# refer to BinaryExpression directly and pass strings
if isinstance(operator, util.string_types):
@@ -3026,15 +3100,15 @@ class BinaryExpression(ColumnElement):
given :class:`BinaryExpression`."""
return (
- isinstance(other, BinaryExpression) and
- self.operator == other.operator and
- (
- self.left.compare(other.left, **kw) and
- self.right.compare(other.right, **kw) or
- (
- operators.is_commutative(self.operator) and
- self.left.compare(other.right, **kw) and
- self.right.compare(other.left, **kw)
+ isinstance(other, BinaryExpression)
+ and self.operator == other.operator
+ and (
+ self.left.compare(other.left, **kw)
+ and self.right.compare(other.right, **kw)
+ or (
+ operators.is_commutative(self.operator)
+ and self.left.compare(other.right, **kw)
+ and self.right.compare(other.left, **kw)
)
)
)
@@ -3053,7 +3127,8 @@ class BinaryExpression(ColumnElement):
self.negate,
negate=self.operator,
type_=self.type,
- modifiers=self.modifiers)
+ modifiers=self.modifiers,
+ )
else:
return super(BinaryExpression, self)._negate()
@@ -3065,7 +3140,8 @@ class Slice(ColumnElement):
may be interpreted by specific dialects, e.g. PostgreSQL.
"""
- __visit_name__ = 'slice'
+
+ __visit_name__ = "slice"
def __init__(self, start, stop, step):
self.start = start
@@ -3081,17 +3157,18 @@ class Slice(ColumnElement):
class IndexExpression(BinaryExpression):
"""Represent the class of expressions that are like an "index" operation.
"""
+
pass
class Grouping(ColumnElement):
"""Represent a grouping within a column expression"""
- __visit_name__ = 'grouping'
+ __visit_name__ = "grouping"
def __init__(self, element):
self.element = element
- self.type = getattr(element, 'type', type_api.NULLTYPE)
+ self.type = getattr(element, "type", type_api.NULLTYPE)
def self_group(self, against=None):
return self
@@ -3106,13 +3183,13 @@ class Grouping(ColumnElement):
@property
def _label(self):
- return getattr(self.element, '_label', None) or self.anon_label
+ return getattr(self.element, "_label", None) or self.anon_label
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
@property
def _from_objects(self):
@@ -3122,15 +3199,16 @@ class Grouping(ColumnElement):
return getattr(self.element, attr)
def __getstate__(self):
- return {'element': self.element, 'type': self.type}
+ return {"element": self.element, "type": self.type}
def __setstate__(self, state):
- self.element = state['element']
- self.type = state['type']
+ self.element = state["element"]
+ self.type = state["type"]
def compare(self, other, **kw):
- return isinstance(other, Grouping) and \
- self.element.compare(other.element)
+ return isinstance(other, Grouping) and self.element.compare(
+ other.element
+ )
RANGE_UNBOUNDED = util.symbol("RANGE_UNBOUNDED")
@@ -3147,14 +3225,15 @@ class Over(ColumnElement):
backends.
"""
- __visit_name__ = 'over'
+
+ __visit_name__ = "over"
order_by = None
partition_by = None
def __init__(
- self, element, partition_by=None,
- order_by=None, range_=None, rows=None):
+ self, element, partition_by=None, order_by=None, range_=None, rows=None
+ ):
"""Produce an :class:`.Over` object against a function.
Used against aggregate or so-called "window" functions,
@@ -3237,17 +3316,20 @@ class Over(ColumnElement):
if order_by is not None:
self.order_by = ClauseList(
*util.to_list(order_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
if partition_by is not None:
self.partition_by = ClauseList(
*util.to_list(partition_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
if range_:
self.range_ = self._interpret_range(range_)
if rows:
raise exc.ArgumentError(
- "'range_' and 'rows' are mutually exclusive")
+ "'range_' and 'rows' are mutually exclusive"
+ )
else:
self.rows = None
elif rows:
@@ -3267,7 +3349,8 @@ class Over(ColumnElement):
lower = int(range_[0])
except ValueError:
raise exc.ArgumentError(
- "Integer or None expected for range value")
+ "Integer or None expected for range value"
+ )
else:
if lower == 0:
lower = RANGE_CURRENT
@@ -3279,7 +3362,8 @@ class Over(ColumnElement):
upper = int(range_[1])
except ValueError:
raise exc.ArgumentError(
- "Integer or None expected for range value")
+ "Integer or None expected for range value"
+ )
else:
if upper == 0:
upper = RANGE_CURRENT
@@ -3303,9 +3387,11 @@ class Over(ColumnElement):
return self.element.type
def get_children(self, **kwargs):
- return [c for c in
- (self.element, self.partition_by, self.order_by)
- if c is not None]
+ return [
+ c
+ for c in (self.element, self.partition_by, self.order_by)
+ if c is not None
+ ]
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
@@ -3316,11 +3402,15 @@ class Over(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(
- *[c._from_objects for c in
- (self.element, self.partition_by, self.order_by)
- if c is not None]
- ))
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.partition_by, self.order_by)
+ if c is not None
+ ]
+ )
+ )
class WithinGroup(ColumnElement):
@@ -3339,7 +3429,8 @@ class WithinGroup(ColumnElement):
``None``, the function's ``.type`` is used.
"""
- __visit_name__ = 'withingroup'
+
+ __visit_name__ = "withingroup"
order_by = None
@@ -3383,7 +3474,8 @@ class WithinGroup(ColumnElement):
if order_by is not None:
self.order_by = ClauseList(
*util.to_list(order_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
def over(self, partition_by=None, order_by=None, range_=None, rows=None):
"""Produce an OVER clause against this :class:`.WithinGroup`
@@ -3394,8 +3486,12 @@ class WithinGroup(ColumnElement):
"""
return Over(
- self, partition_by=partition_by, order_by=order_by,
- range_=range_, rows=rows)
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
@util.memoized_property
def type(self):
@@ -3406,9 +3502,7 @@ class WithinGroup(ColumnElement):
return self.element.type
def get_children(self, **kwargs):
- return [c for c in
- (self.element, self.order_by)
- if c is not None]
+ return [c for c in (self.element, self.order_by) if c is not None]
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
@@ -3417,11 +3511,15 @@ class WithinGroup(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(
- *[c._from_objects for c in
- (self.element, self.order_by)
- if c is not None]
- ))
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.element, self.order_by)
+ if c is not None
+ ]
+ )
+ )
class FunctionFilter(ColumnElement):
@@ -3443,7 +3541,8 @@ class FunctionFilter(ColumnElement):
:meth:`.FunctionElement.filter`
"""
- __visit_name__ = 'funcfilter'
+
+ __visit_name__ = "funcfilter"
criterion = None
@@ -3515,17 +3614,19 @@ class FunctionFilter(ColumnElement):
"""
return Over(
- self, partition_by=partition_by, order_by=order_by,
- range_=range_, rows=rows)
+ self,
+ partition_by=partition_by,
+ order_by=order_by,
+ range_=range_,
+ rows=rows,
+ )
@util.memoized_property
def type(self):
return self.func.type
def get_children(self, **kwargs):
- return [c for c in
- (self.func, self.criterion)
- if c is not None]
+ return [c for c in (self.func, self.criterion) if c is not None]
def _copy_internals(self, clone=_clone, **kw):
self.func = clone(self.func, **kw)
@@ -3534,10 +3635,15 @@ class FunctionFilter(ColumnElement):
@property
def _from_objects(self):
- return list(itertools.chain(
- *[c._from_objects for c in (self.func, self.criterion)
- if c is not None]
- ))
+ return list(
+ itertools.chain(
+ *[
+ c._from_objects
+ for c in (self.func, self.criterion)
+ if c is not None
+ ]
+ )
+ )
class Label(ColumnElement):
@@ -3548,7 +3654,7 @@ class Label(ColumnElement):
"""
- __visit_name__ = 'label'
+ __visit_name__ = "label"
def __init__(self, name, element, type_=None):
"""Return a :class:`Label` object for the
@@ -3577,7 +3683,7 @@ class Label(ColumnElement):
self._resolve_label = self.name
else:
self.name = _anonymous_label(
- '%%(%d %s)s' % (id(self), getattr(element, 'name', 'anon'))
+ "%%(%d %s)s" % (id(self), getattr(element, "name", "anon"))
)
self.key = self._label = self._key_label = self.name
@@ -3603,7 +3709,7 @@ class Label(ColumnElement):
@util.memoized_property
def type(self):
return type_api.to_instance(
- self._type or getattr(self._element, 'type', None)
+ self._type or getattr(self._element, "type", None)
)
@util.memoized_property
@@ -3619,9 +3725,7 @@ class Label(ColumnElement):
def _apply_to_inner(self, fn, *arg, **kw):
sub_element = fn(*arg, **kw)
if sub_element is not self._element:
- return Label(self.name,
- sub_element,
- type_=self._type)
+ return Label(self.name, sub_element, type_=self._type)
else:
return self
@@ -3634,16 +3738,16 @@ class Label(ColumnElement):
return self.element.foreign_keys
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
def _copy_internals(self, clone=_clone, anonymize_labels=False, **kw):
self._element = clone(self._element, **kw)
- self.__dict__.pop('element', None)
- self.__dict__.pop('_allow_label_resolve', None)
+ self.__dict__.pop("element", None)
+ self.__dict__.pop("_allow_label_resolve", None)
if anonymize_labels:
self.name = self._resolve_label = _anonymous_label(
- '%%(%d %s)s' % (
- id(self), getattr(self.element, 'name', 'anon'))
+ "%%(%d %s)s"
+ % (id(self), getattr(self.element, "name", "anon"))
)
self.key = self._label = self._key_label = self.name
@@ -3652,8 +3756,9 @@ class Label(ColumnElement):
return self.element._from_objects
def _make_proxy(self, selectable, name=None, **kw):
- e = self.element._make_proxy(selectable,
- name=name if name else self.name)
+ e = self.element._make_proxy(
+ selectable, name=name if name else self.name
+ )
e._proxies.append(self)
if self._type is not None:
e.type = self._type
@@ -3694,7 +3799,8 @@ class ColumnClause(Immutable, ColumnElement):
:class:`.Column`
"""
- __visit_name__ = 'column'
+
+ __visit_name__ = "column"
onupdate = default = server_default = server_onupdate = None
@@ -3792,25 +3898,33 @@ class ColumnClause(Immutable, ColumnElement):
self.is_literal = is_literal
def _compare_name_for_result(self, other):
- if self.is_literal or \
- self.table is None or self.table._textual or \
- not hasattr(other, 'proxy_set') or (
- isinstance(other, ColumnClause) and
- (other.is_literal or
- other.table is None or
- other.table._textual)
- ):
- return (hasattr(other, 'name') and self.name == other.name) or \
- (hasattr(other, '_label') and self._label == other._label)
+ if (
+ self.is_literal
+ or self.table is None
+ or self.table._textual
+ or not hasattr(other, "proxy_set")
+ or (
+ isinstance(other, ColumnClause)
+ and (
+ other.is_literal
+ or other.table is None
+ or other.table._textual
+ )
+ )
+ ):
+ return (hasattr(other, "name") and self.name == other.name) or (
+ hasattr(other, "_label") and self._label == other._label
+ )
else:
return other.proxy_set.intersection(self.proxy_set)
def _get_table(self):
- return self.__dict__['table']
+ return self.__dict__["table"]
def _set_table(self, table):
self._memoized_property.expire_instance(self)
- self.__dict__['table'] = table
+ self.__dict__["table"] = table
+
table = property(_get_table, _set_table)
@_memoized_property
@@ -3826,7 +3940,7 @@ class ColumnClause(Immutable, ColumnElement):
if util.py3k:
return self.name
else:
- return self.name.encode('ascii', 'backslashreplace')
+ return self.name.encode("ascii", "backslashreplace")
@_memoized_property
def _key_label(self):
@@ -3850,9 +3964,8 @@ class ColumnClause(Immutable, ColumnElement):
return None
elif t is not None and t.named_with_column:
- if getattr(t, 'schema', None):
- label = t.schema.replace('.', '_') + "_" + \
- t.name + "_" + name
+ if getattr(t, "schema", None):
+ label = t.schema.replace(".", "_") + "_" + t.name + "_" + name
else:
label = t.name + "_" + name
@@ -3884,31 +3997,39 @@ class ColumnClause(Immutable, ColumnElement):
return name
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(self.key, obj,
- _compared_to_operator=operator,
- _compared_to_type=self.type,
- type_=type_,
- unique=True)
-
- def _make_proxy(self, selectable, name=None, attach=True,
- name_is_truncatable=False, **kw):
+ return BindParameter(
+ self.key,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ )
+
+ def _make_proxy(
+ self,
+ selectable,
+ name=None,
+ attach=True,
+ name_is_truncatable=False,
+ **kw
+ ):
# propagate the "is_literal" flag only if we are keeping our name,
# otherwise its considered to be a label
is_literal = self.is_literal and (name is None or name == self.name)
c = self._constructor(
- _as_truncated(name or self.name) if
- name_is_truncatable else
- (name or self.name),
+ _as_truncated(name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
type_=self.type,
_selectable=selectable,
- is_literal=is_literal
+ is_literal=is_literal,
)
if name is None:
c.key = self.key
c._proxies = [self]
if selectable._is_clone_of is not None:
- c._is_clone_of = \
- selectable._is_clone_of.columns.get(c.key)
+ c._is_clone_of = selectable._is_clone_of.columns.get(c.key)
if attach:
selectable._columns[c.key] = c
@@ -3924,24 +4045,25 @@ class CollationClause(ColumnElement):
class _IdentifiedClause(Executable, ClauseElement):
- __visit_name__ = 'identified'
- _execution_options = \
- Executable._execution_options.union({'autocommit': False})
+ __visit_name__ = "identified"
+ _execution_options = Executable._execution_options.union(
+ {"autocommit": False}
+ )
def __init__(self, ident):
self.ident = ident
class SavepointClause(_IdentifiedClause):
- __visit_name__ = 'savepoint'
+ __visit_name__ = "savepoint"
class RollbackToSavepointClause(_IdentifiedClause):
- __visit_name__ = 'rollback_to_savepoint'
+ __visit_name__ = "rollback_to_savepoint"
class ReleaseSavepointClause(_IdentifiedClause):
- __visit_name__ = 'release_savepoint'
+ __visit_name__ = "release_savepoint"
class quoted_name(util.MemoizedSlots, util.text_type):
@@ -3992,7 +4114,7 @@ class quoted_name(util.MemoizedSlots, util.text_type):
"""
- __slots__ = 'quote', 'lower', 'upper'
+ __slots__ = "quote", "lower", "upper"
def __new__(cls, value, quote):
if value is None:
@@ -4026,9 +4148,9 @@ class quoted_name(util.MemoizedSlots, util.text_type):
return util.text_type(self).upper()
def __repr__(self):
- backslashed = self.encode('ascii', 'backslashreplace')
+ backslashed = self.encode("ascii", "backslashreplace")
if not util.py2k:
- backslashed = backslashed.decode('ascii')
+ backslashed = backslashed.decode("ascii")
return "'%s'" % backslashed
@@ -4094,6 +4216,7 @@ class conv(_truncated_label):
:ref:`constraint_naming_conventions`
"""
+
__slots__ = ()
@@ -4102,6 +4225,7 @@ class _defer_name(_truncated_label):
generation.
"""
+
__slots__ = ()
def __new__(cls, value):
@@ -4113,13 +4237,15 @@ class _defer_name(_truncated_label):
return super(_defer_name, cls).__new__(cls, value)
def __reduce__(self):
- return self.__class__, (util.text_type(self), )
+ return self.__class__, (util.text_type(self),)
class _defer_none_name(_defer_name):
"""indicate a 'deferred' name that was ultimately the value None."""
+
__slots__ = ()
+
_NONE_NAME = _defer_none_name("_unnamed_")
# for backwards compatibility in case
@@ -4138,15 +4264,15 @@ class _anonymous_label(_truncated_label):
def __add__(self, other):
return _anonymous_label(
quoted_name(
- util.text_type.__add__(self, util.text_type(other)),
- self.quote)
+ util.text_type.__add__(self, util.text_type(other)), self.quote
+ )
)
def __radd__(self, other):
return _anonymous_label(
quoted_name(
- util.text_type.__add__(util.text_type(other), self),
- self.quote)
+ util.text_type.__add__(util.text_type(other), self), self.quote
+ )
)
def apply_map(self, map_):
@@ -4206,20 +4332,23 @@ def _cloned_intersection(a, b):
"""
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
- return set(elem for elem in a
- if all_overlap.intersection(elem._cloned_set))
+ return set(
+ elem for elem in a if all_overlap.intersection(elem._cloned_set)
+ )
def _cloned_difference(a, b):
all_overlap = set(_expand_cloned(a)).intersection(_expand_cloned(b))
- return set(elem for elem in a
- if not all_overlap.intersection(elem._cloned_set))
+ return set(
+ elem for elem in a if not all_overlap.intersection(elem._cloned_set)
+ )
@util.dependencies("sqlalchemy.sql.functions")
def _labeled(functions, element):
- if not hasattr(element, 'name') or \
- isinstance(element, functions.FunctionElement):
+ if not hasattr(element, "name") or isinstance(
+ element, functions.FunctionElement
+ ):
return element.label(None)
else:
return element
@@ -4235,7 +4364,7 @@ def _find_columns(clause):
"""locate Column objects within the given expression."""
cols = util.column_set()
- traverse(clause, {}, {'column': cols.add})
+ traverse(clause, {}, {"column": cols.add})
return cols
@@ -4253,7 +4382,7 @@ def _find_columns(clause):
def _column_as_key(element):
if isinstance(element, util.string_types):
return element
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
try:
return element.key
@@ -4262,7 +4391,7 @@ def _column_as_key(element):
def _clause_element_as_expr(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
return element.__clause_element__()
else:
return element
@@ -4272,7 +4401,7 @@ def _literal_as_label_reference(element):
if isinstance(element, util.string_types):
return _textual_label_reference(element)
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
element = element.__clause_element__()
return _literal_as_text(element)
@@ -4282,11 +4411,13 @@ def _literal_and_labels_as_label_reference(element):
if isinstance(element, util.string_types):
return _textual_label_reference(element)
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
element = element.__clause_element__()
- if isinstance(element, ColumnElement) and \
- element._order_by_label_element is not None:
+ if (
+ isinstance(element, ColumnElement)
+ and element._order_by_label_element is not None
+ ):
return _label_reference(element)
else:
return _literal_as_text(element)
@@ -4299,14 +4430,15 @@ def _expression_literal_as_text(element):
def _literal_as_text(element, warn=False):
if isinstance(element, Visitable):
return element
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif isinstance(element, util.string_types):
if warn:
util.warn_limited(
"Textual SQL expression %(expr)r should be "
"explicitly declared as text(%(expr)r)",
- {"expr": util.ellipses_string(element)})
+ {"expr": util.ellipses_string(element)},
+ )
return TextClause(util.text_type(element))
elif isinstance(element, (util.NoneType, bool)):
@@ -4319,20 +4451,23 @@ def _literal_as_text(element, warn=False):
def _no_literals(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif not isinstance(element, Visitable):
- raise exc.ArgumentError("Ambiguous literal: %r. Use the 'text()' "
- "function to indicate a SQL expression "
- "literal, or 'literal()' to indicate a "
- "bound value." % (element, ))
+ raise exc.ArgumentError(
+ "Ambiguous literal: %r. Use the 'text()' "
+ "function to indicate a SQL expression "
+ "literal, or 'literal()' to indicate a "
+ "bound value." % (element,)
+ )
else:
return element
def _is_literal(element):
- return not isinstance(element, Visitable) and \
- not hasattr(element, '__clause_element__')
+ return not isinstance(element, Visitable) and not hasattr(
+ element, "__clause_element__"
+ )
def _only_column_elements_or_none(element, name):
@@ -4343,17 +4478,18 @@ def _only_column_elements_or_none(element, name):
def _only_column_elements(element, name):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
if not isinstance(element, ColumnElement):
raise exc.ArgumentError(
"Column-based expression object expected for argument "
- "'%s'; got: '%s', type %s" % (name, element, type(element)))
+ "'%s'; got: '%s', type %s" % (name, element, type(element))
+ )
return element
def _literal_as_binds(element, name=None, type_=None):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif not isinstance(element, Visitable):
if element is None:
@@ -4363,13 +4499,14 @@ def _literal_as_binds(element, name=None, type_=None):
else:
return element
-_guess_straight_column = re.compile(r'^\w\S*$', re.I)
+
+_guess_straight_column = re.compile(r"^\w\S*$", re.I)
def _interpret_as_column_or_from(element):
if isinstance(element, Visitable):
return element
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
insp = inspection.inspect(element, raiseerr=False)
@@ -4399,11 +4536,11 @@ def _interpret_as_column_or_from(element):
{
"column": util.ellipses_string(element),
"literal_column": "literal_column"
- if guess_is_literal else "column"
- })
- return ColumnClause(
- element,
- is_literal=guess_is_literal)
+ if guess_is_literal
+ else "column",
+ },
+ )
+ return ColumnClause(element, is_literal=guess_is_literal)
def _const_expr(element):
@@ -4416,9 +4553,7 @@ def _const_expr(element):
elif element is True:
return True_()
else:
- raise exc.ArgumentError(
- "Expected None, False, or True"
- )
+ raise exc.ArgumentError("Expected None, False, or True")
def _type_from_args(args):
@@ -4429,18 +4564,15 @@ def _type_from_args(args):
return type_api.NULLTYPE
-def _corresponding_column_or_error(fromclause, column,
- require_embedded=False):
- c = fromclause.corresponding_column(column,
- require_embedded=require_embedded)
+def _corresponding_column_or_error(fromclause, column, require_embedded=False):
+ c = fromclause.corresponding_column(
+ column, require_embedded=require_embedded
+ )
if c is None:
raise exc.InvalidRequestError(
"Given column '%s', attached to table '%s', "
"failed to locate a corresponding column from table '%s'"
- %
- (column,
- getattr(column, 'table', None),
- fromclause.description)
+ % (column, getattr(column, "table", None), fromclause.description)
)
return c
@@ -4449,7 +4581,7 @@ class AnnotatedColumnElement(Annotated):
def __init__(self, element, values):
Annotated.__init__(self, element, values)
ColumnElement.comparator._reset(self)
- for attr in ('name', 'key', 'table'):
+ for attr in ("name", "key", "table"):
if self.__dict__.get(attr, False) is None:
self.__dict__.pop(attr)
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index b69b6ee8c..aab9f46d4 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -15,43 +15,142 @@ class.
"""
__all__ = [
- 'Alias', 'any_', 'all_', 'ClauseElement', 'ColumnCollection', 'ColumnElement',
- 'CompoundSelect', 'Delete', 'FromClause', 'Insert', 'Join', 'Lateral',
- 'Select',
- 'Selectable', 'TableClause', 'Update', 'alias', 'and_', 'asc', 'between',
- 'bindparam', 'case', 'cast', 'column', 'delete', 'desc', 'distinct',
- 'except_', 'except_all', 'exists', 'extract', 'func', 'modifier',
- 'collate', 'insert', 'intersect', 'intersect_all', 'join', 'label',
- 'lateral', 'literal', 'literal_column', 'not_', 'null', 'nullsfirst',
- 'nullslast',
- 'or_', 'outparam', 'outerjoin', 'over', 'select', 'subquery',
- 'table', 'text',
- 'tuple_', 'type_coerce', 'quoted_name', 'union', 'union_all', 'update',
- 'within_group',
- 'TableSample', 'tablesample']
+ "Alias",
+ "any_",
+ "all_",
+ "ClauseElement",
+ "ColumnCollection",
+ "ColumnElement",
+ "CompoundSelect",
+ "Delete",
+ "FromClause",
+ "Insert",
+ "Join",
+ "Lateral",
+ "Select",
+ "Selectable",
+ "TableClause",
+ "Update",
+ "alias",
+ "and_",
+ "asc",
+ "between",
+ "bindparam",
+ "case",
+ "cast",
+ "column",
+ "delete",
+ "desc",
+ "distinct",
+ "except_",
+ "except_all",
+ "exists",
+ "extract",
+ "func",
+ "modifier",
+ "collate",
+ "insert",
+ "intersect",
+ "intersect_all",
+ "join",
+ "label",
+ "lateral",
+ "literal",
+ "literal_column",
+ "not_",
+ "null",
+ "nullsfirst",
+ "nullslast",
+ "or_",
+ "outparam",
+ "outerjoin",
+ "over",
+ "select",
+ "subquery",
+ "table",
+ "text",
+ "tuple_",
+ "type_coerce",
+ "quoted_name",
+ "union",
+ "union_all",
+ "update",
+ "within_group",
+ "TableSample",
+ "tablesample",
+]
from .visitors import Visitable
from .functions import func, modifier, FunctionElement, Function
from ..util.langhelpers import public_factory
-from .elements import ClauseElement, ColumnElement,\
- BindParameter, CollectionAggregate, UnaryExpression, BooleanClauseList, \
- Label, Cast, Case, ColumnClause, TextClause, Over, Null, \
- True_, False_, BinaryExpression, Tuple, TypeClause, Extract, \
- Grouping, WithinGroup, not_, quoted_name, \
- collate, literal_column, between,\
- literal, outparam, TypeCoerce, ClauseList, FunctionFilter
+from .elements import (
+ ClauseElement,
+ ColumnElement,
+ BindParameter,
+ CollectionAggregate,
+ UnaryExpression,
+ BooleanClauseList,
+ Label,
+ Cast,
+ Case,
+ ColumnClause,
+ TextClause,
+ Over,
+ Null,
+ True_,
+ False_,
+ BinaryExpression,
+ Tuple,
+ TypeClause,
+ Extract,
+ Grouping,
+ WithinGroup,
+ not_,
+ quoted_name,
+ collate,
+ literal_column,
+ between,
+ literal,
+ outparam,
+ TypeCoerce,
+ ClauseList,
+ FunctionFilter,
+)
-from .elements import SavepointClause, RollbackToSavepointClause, \
- ReleaseSavepointClause
+from .elements import (
+ SavepointClause,
+ RollbackToSavepointClause,
+ ReleaseSavepointClause,
+)
-from .base import ColumnCollection, Generative, Executable, \
- PARSE_AUTOCOMMIT
+from .base import ColumnCollection, Generative, Executable, PARSE_AUTOCOMMIT
-from .selectable import Alias, Join, Select, Selectable, TableClause, \
- CompoundSelect, CTE, FromClause, FromGrouping, Lateral, SelectBase, \
- alias, GenerativeSelect, subquery, HasCTE, HasPrefixes, HasSuffixes, \
- lateral, Exists, ScalarSelect, TextAsFrom, TableSample, tablesample
+from .selectable import (
+ Alias,
+ Join,
+ Select,
+ Selectable,
+ TableClause,
+ CompoundSelect,
+ CTE,
+ FromClause,
+ FromGrouping,
+ Lateral,
+ SelectBase,
+ alias,
+ GenerativeSelect,
+ subquery,
+ HasCTE,
+ HasPrefixes,
+ HasSuffixes,
+ lateral,
+ Exists,
+ ScalarSelect,
+ TextAsFrom,
+ TableSample,
+ tablesample,
+)
from .dml import Insert, Update, Delete, UpdateBase, ValuesBase
@@ -79,23 +178,30 @@ extract = public_factory(Extract, ".expression.extract")
tuple_ = public_factory(Tuple, ".expression.tuple_")
except_ = public_factory(CompoundSelect._create_except, ".expression.except_")
except_all = public_factory(
- CompoundSelect._create_except_all, ".expression.except_all")
+ CompoundSelect._create_except_all, ".expression.except_all"
+)
intersect = public_factory(
- CompoundSelect._create_intersect, ".expression.intersect")
+ CompoundSelect._create_intersect, ".expression.intersect"
+)
intersect_all = public_factory(
- CompoundSelect._create_intersect_all, ".expression.intersect_all")
+ CompoundSelect._create_intersect_all, ".expression.intersect_all"
+)
union = public_factory(CompoundSelect._create_union, ".expression.union")
union_all = public_factory(
- CompoundSelect._create_union_all, ".expression.union_all")
+ CompoundSelect._create_union_all, ".expression.union_all"
+)
exists = public_factory(Exists, ".expression.exists")
nullsfirst = public_factory(
- UnaryExpression._create_nullsfirst, ".expression.nullsfirst")
+ UnaryExpression._create_nullsfirst, ".expression.nullsfirst"
+)
nullslast = public_factory(
- UnaryExpression._create_nullslast, ".expression.nullslast")
+ UnaryExpression._create_nullslast, ".expression.nullslast"
+)
asc = public_factory(UnaryExpression._create_asc, ".expression.asc")
desc = public_factory(UnaryExpression._create_desc, ".expression.desc")
distinct = public_factory(
- UnaryExpression._create_distinct, ".expression.distinct")
+ UnaryExpression._create_distinct, ".expression.distinct"
+)
type_coerce = public_factory(TypeCoerce, ".expression.type_coerce")
true = public_factory(True_._instance, ".expression.true")
false = public_factory(False_._instance, ".expression.false")
@@ -105,19 +211,30 @@ outerjoin = public_factory(Join._create_outerjoin, ".expression.outerjoin")
insert = public_factory(Insert, ".expression.insert")
update = public_factory(Update, ".expression.update")
delete = public_factory(Delete, ".expression.delete")
-funcfilter = public_factory(
- FunctionFilter, ".expression.funcfilter")
+funcfilter = public_factory(FunctionFilter, ".expression.funcfilter")
# internal functions still being called from tests and the ORM,
# these might be better off in some other namespace
from .base import _from_objects
-from .elements import _literal_as_text, _clause_element_as_expr,\
- _is_column, _labeled, _only_column_elements, _string_or_unprintable, \
- _truncated_label, _clone, _cloned_difference, _cloned_intersection,\
- _column_as_key, _literal_as_binds, _select_iterables, \
- _corresponding_column_or_error, _literal_as_label_reference, \
- _expression_literal_as_text
+from .elements import (
+ _literal_as_text,
+ _clause_element_as_expr,
+ _is_column,
+ _labeled,
+ _only_column_elements,
+ _string_or_unprintable,
+ _truncated_label,
+ _clone,
+ _cloned_difference,
+ _cloned_intersection,
+ _column_as_key,
+ _literal_as_binds,
+ _select_iterables,
+ _corresponding_column_or_error,
+ _literal_as_label_reference,
+ _expression_literal_as_text,
+)
from .selectable import _interpret_as_from
diff --git a/lib/sqlalchemy/sql/functions.py b/lib/sqlalchemy/sql/functions.py
index 4b4d2d463..883bb8cc3 100644
--- a/lib/sqlalchemy/sql/functions.py
+++ b/lib/sqlalchemy/sql/functions.py
@@ -10,10 +10,22 @@
"""
from . import sqltypes, schema
from .base import Executable, ColumnCollection
-from .elements import ClauseList, Cast, Extract, _literal_as_binds, \
- literal_column, _type_from_args, ColumnElement, _clone,\
- Over, BindParameter, FunctionFilter, Grouping, WithinGroup, \
- BinaryExpression
+from .elements import (
+ ClauseList,
+ Cast,
+ Extract,
+ _literal_as_binds,
+ literal_column,
+ _type_from_args,
+ ColumnElement,
+ _clone,
+ Over,
+ BindParameter,
+ FunctionFilter,
+ Grouping,
+ WithinGroup,
+ BinaryExpression,
+)
from .selectable import FromClause, Select, Alias
from . import util as sqlutil
from . import operators
@@ -62,9 +74,8 @@ class FunctionElement(Executable, ColumnElement, FromClause):
args = [_literal_as_binds(c, self.name) for c in clauses]
self._has_args = self._has_args or bool(args)
self.clause_expr = ClauseList(
- operator=operators.comma_op,
- group_contents=True, *args).\
- self_group()
+ operator=operators.comma_op, group_contents=True, *args
+ ).self_group()
def _execute_on_connection(self, connection, multiparams, params):
return connection._execute_function(self, multiparams, params)
@@ -123,7 +134,7 @@ class FunctionElement(Executable, ColumnElement, FromClause):
partition_by=partition_by,
order_by=order_by,
rows=rows,
- range_=range_
+ range_=range_,
)
def within_group(self, *order_by):
@@ -233,16 +244,14 @@ class FunctionElement(Executable, ColumnElement, FromClause):
.. versionadded:: 1.3
"""
- return FunctionAsBinary(
- self, left_index, right_index
- )
+ return FunctionAsBinary(self, left_index, right_index)
@property
def _from_objects(self):
return self.clauses._from_objects
def get_children(self, **kwargs):
- return self.clause_expr,
+ return (self.clause_expr,)
def _copy_internals(self, clone=_clone, **kw):
self.clause_expr = clone(self.clause_expr, **kw)
@@ -336,24 +345,29 @@ class FunctionElement(Executable, ColumnElement, FromClause):
return self.select().execute()
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(None, obj, _compared_to_operator=operator,
- _compared_to_type=self.type, unique=True,
- type_=type_)
+ return BindParameter(
+ None,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ unique=True,
+ type_=type_,
+ )
def self_group(self, against=None):
# for the moment, we are parenthesizing all array-returning
# expressions against getitem. This may need to be made
# more portable if in the future we support other DBs
# besides postgresql.
- if against is operators.getitem and \
- isinstance(self.type, sqltypes.ARRAY):
+ if against is operators.getitem and isinstance(
+ self.type, sqltypes.ARRAY
+ ):
return Grouping(self)
else:
return super(FunctionElement, self).self_group(against=against)
class FunctionAsBinary(BinaryExpression):
-
def __init__(self, fn, left_index, right_index):
left = fn.clauses.clauses[left_index - 1]
right = fn.clauses.clauses[right_index - 1]
@@ -362,8 +376,11 @@ class FunctionAsBinary(BinaryExpression):
self.right_index = right_index
super(FunctionAsBinary, self).__init__(
- left, right, operators.function_as_comparison_op,
- type_=sqltypes.BOOLEANTYPE)
+ left,
+ right,
+ operators.function_as_comparison_op,
+ type_=sqltypes.BOOLEANTYPE,
+ )
@property
def left(self):
@@ -382,7 +399,7 @@ class FunctionAsBinary(BinaryExpression):
self.sql_function.clauses.clauses[self.right_index - 1] = value
def _copy_internals(self, **kw):
- clone = kw.pop('clone')
+ clone = kw.pop("clone")
self.sql_function = clone(self.sql_function, **kw)
super(FunctionAsBinary, self)._copy_internals(**kw)
@@ -396,13 +413,13 @@ class _FunctionGenerator(object):
def __getattr__(self, name):
# passthru __ attributes; fixes pydoc
- if name.startswith('__'):
+ if name.startswith("__"):
try:
return self.__dict__[name]
except KeyError:
raise AttributeError(name)
- elif name.endswith('_'):
+ elif name.endswith("_"):
name = name[0:-1]
f = _FunctionGenerator(**self.opts)
f.__names = list(self.__names) + [name]
@@ -426,8 +443,9 @@ class _FunctionGenerator(object):
if func is not None:
return func(*c, **o)
- return Function(self.__names[-1],
- packagenames=self.__names[0:-1], *c, **o)
+ return Function(
+ self.__names[-1], packagenames=self.__names[0:-1], *c, **o
+ )
func = _FunctionGenerator()
@@ -523,7 +541,7 @@ class Function(FunctionElement):
"""
- __visit_name__ = 'function'
+ __visit_name__ = "function"
def __init__(self, name, *clauses, **kw):
"""Construct a :class:`.Function`.
@@ -532,30 +550,33 @@ class Function(FunctionElement):
new :class:`.Function` instances.
"""
- self.packagenames = kw.pop('packagenames', None) or []
+ self.packagenames = kw.pop("packagenames", None) or []
self.name = name
- self._bind = kw.get('bind', None)
- self.type = sqltypes.to_instance(kw.get('type_', None))
+ self._bind = kw.get("bind", None)
+ self.type = sqltypes.to_instance(kw.get("type_", None))
FunctionElement.__init__(self, *clauses, **kw)
def _bind_param(self, operator, obj, type_=None):
- return BindParameter(self.name, obj,
- _compared_to_operator=operator,
- _compared_to_type=self.type,
- type_=type_,
- unique=True)
+ return BindParameter(
+ self.name,
+ obj,
+ _compared_to_operator=operator,
+ _compared_to_type=self.type,
+ type_=type_,
+ unique=True,
+ )
class _GenericMeta(VisitableType):
def __init__(cls, clsname, bases, clsdict):
if annotation.Annotated not in cls.__mro__:
- cls.name = name = clsdict.get('name', clsname)
- cls.identifier = identifier = clsdict.get('identifier', name)
- package = clsdict.pop('package', '_default')
+ cls.name = name = clsdict.get("name", clsname)
+ cls.identifier = identifier = clsdict.get("identifier", name)
+ package = clsdict.pop("package", "_default")
# legacy
- if '__return_type__' in clsdict:
- cls.type = clsdict['__return_type__']
+ if "__return_type__" in clsdict:
+ cls.type = clsdict["__return_type__"]
register_function(identifier, cls, package)
super(_GenericMeta, cls).__init__(clsname, bases, clsdict)
@@ -635,17 +656,19 @@ class GenericFunction(util.with_metaclass(_GenericMeta, Function)):
coerce_arguments = True
def __init__(self, *args, **kwargs):
- parsed_args = kwargs.pop('_parsed_args', None)
+ parsed_args = kwargs.pop("_parsed_args", None)
if parsed_args is None:
parsed_args = [_literal_as_binds(c, self.name) for c in args]
self._has_args = self._has_args or bool(parsed_args)
self.packagenames = []
- self._bind = kwargs.get('bind', None)
+ self._bind = kwargs.get("bind", None)
self.clause_expr = ClauseList(
- operator=operators.comma_op,
- group_contents=True, *parsed_args).self_group()
+ operator=operators.comma_op, group_contents=True, *parsed_args
+ ).self_group()
self.type = sqltypes.to_instance(
- kwargs.pop("type_", None) or getattr(self, 'type', None))
+ kwargs.pop("type_", None) or getattr(self, "type", None)
+ )
+
register_function("cast", Cast)
register_function("extract", Extract)
@@ -660,13 +683,15 @@ class next_value(GenericFunction):
that does not provide support for sequences.
"""
+
type = sqltypes.Integer()
name = "next_value"
def __init__(self, seq, **kw):
- assert isinstance(seq, schema.Sequence), \
- "next_value() accepts a Sequence object as input."
- self._bind = kw.get('bind', None)
+ assert isinstance(
+ seq, schema.Sequence
+ ), "next_value() accepts a Sequence object as input."
+ self._bind = kw.get("bind", None)
self.sequence = seq
@property
@@ -684,8 +709,8 @@ class ReturnTypeFromArgs(GenericFunction):
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c, self.name) for c in args]
- kwargs.setdefault('type_', _type_from_args(args))
- kwargs['_parsed_args'] = args
+ kwargs.setdefault("type_", _type_from_args(args))
+ kwargs["_parsed_args"] = args
super(ReturnTypeFromArgs, self).__init__(*args, **kwargs)
@@ -733,7 +758,7 @@ class count(GenericFunction):
def __init__(self, expression=None, **kwargs):
if expression is None:
- expression = literal_column('*')
+ expression = literal_column("*")
super(count, self).__init__(expression, **kwargs)
@@ -797,15 +822,15 @@ class array_agg(GenericFunction):
def __init__(self, *args, **kwargs):
args = [_literal_as_binds(c) for c in args]
- default_array_type = kwargs.pop('_default_array_type', sqltypes.ARRAY)
- if 'type_' not in kwargs:
+ default_array_type = kwargs.pop("_default_array_type", sqltypes.ARRAY)
+ if "type_" not in kwargs:
type_from_args = _type_from_args(args)
if isinstance(type_from_args, sqltypes.ARRAY):
- kwargs['type_'] = type_from_args
+ kwargs["type_"] = type_from_args
else:
- kwargs['type_'] = default_array_type(type_from_args)
- kwargs['_parsed_args'] = args
+ kwargs["type_"] = default_array_type(type_from_args)
+ kwargs["_parsed_args"] = args
super(array_agg, self).__init__(*args, **kwargs)
@@ -883,6 +908,7 @@ class rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Integer()
@@ -897,6 +923,7 @@ class dense_rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Integer()
@@ -911,6 +938,7 @@ class percent_rank(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Numeric()
@@ -925,6 +953,7 @@ class cume_dist(GenericFunction):
.. versionadded:: 1.1
"""
+
type = sqltypes.Numeric()
diff --git a/lib/sqlalchemy/sql/naming.py b/lib/sqlalchemy/sql/naming.py
index 0107ce724..144cc4dfc 100644
--- a/lib/sqlalchemy/sql/naming.py
+++ b/lib/sqlalchemy/sql/naming.py
@@ -10,8 +10,16 @@
"""
-from .schema import Constraint, ForeignKeyConstraint, PrimaryKeyConstraint, \
- UniqueConstraint, CheckConstraint, Index, Table, Column
+from .schema import (
+ Constraint,
+ ForeignKeyConstraint,
+ PrimaryKeyConstraint,
+ UniqueConstraint,
+ CheckConstraint,
+ Index,
+ Table,
+ Column,
+)
from .. import event, events
from .. import exc
from .elements import _truncated_label, _defer_name, _defer_none_name, conv
@@ -19,7 +27,6 @@ import re
class ConventionDict(object):
-
def __init__(self, const, table, convention):
self.const = const
self._is_fk = isinstance(const, ForeignKeyConstraint)
@@ -79,8 +86,8 @@ class ConventionDict(object):
def __getitem__(self, key):
if key in self.convention:
return self.convention[key](self.const, self.table)
- elif hasattr(self, '_key_%s' % key):
- return getattr(self, '_key_%s' % key)()
+ elif hasattr(self, "_key_%s" % key):
+ return getattr(self, "_key_%s" % key)()
else:
col_template = re.match(r".*_?column_(\d+)(_?N)?_.+", key)
if col_template:
@@ -108,12 +115,13 @@ class ConventionDict(object):
return getattr(self, attr)(idx)
raise KeyError(key)
+
_prefix_dict = {
Index: "ix",
PrimaryKeyConstraint: "pk",
CheckConstraint: "ck",
UniqueConstraint: "uq",
- ForeignKeyConstraint: "fk"
+ ForeignKeyConstraint: "fk",
}
@@ -134,15 +142,18 @@ def _constraint_name_for_table(const, table):
if isinstance(const.name, conv):
return const.name
- elif convention is not None and \
- not isinstance(const.name, conv) and \
- (
- const.name is None or
- "constraint_name" in convention or
- isinstance(const.name, _defer_name)):
+ elif (
+ convention is not None
+ and not isinstance(const.name, conv)
+ and (
+ const.name is None
+ or "constraint_name" in convention
+ or isinstance(const.name, _defer_name)
+ )
+ ):
return conv(
- convention % ConventionDict(const, table,
- metadata.naming_convention)
+ convention
+ % ConventionDict(const, table, metadata.naming_convention)
)
elif isinstance(convention, _defer_none_name):
return None
@@ -155,9 +166,11 @@ def _constraint_name(const, table):
# for column-attached constraint, set another event
# to link the column attached to the table as this constraint
# associated with the table.
- event.listen(table, "after_parent_attach",
- lambda col, table: _constraint_name(const, table)
- )
+ event.listen(
+ table,
+ "after_parent_attach",
+ lambda col, table: _constraint_name(const, table),
+ )
elif isinstance(table, Table):
if isinstance(const.name, (conv, _defer_name)):
return
diff --git a/lib/sqlalchemy/sql/operators.py b/lib/sqlalchemy/sql/operators.py
index 5b4a28a06..2b843d751 100644
--- a/lib/sqlalchemy/sql/operators.py
+++ b/lib/sqlalchemy/sql/operators.py
@@ -13,8 +13,25 @@
from .. import util
from operator import (
- and_, or_, inv, add, mul, sub, mod, truediv, lt, le, ne, gt, ge, eq, neg,
- getitem, lshift, rshift, contains
+ and_,
+ or_,
+ inv,
+ add,
+ mul,
+ sub,
+ mod,
+ truediv,
+ lt,
+ le,
+ ne,
+ gt,
+ ge,
+ eq,
+ neg,
+ getitem,
+ lshift,
+ rshift,
+ contains,
)
if util.py2k:
@@ -37,6 +54,7 @@ class Operators(object):
:class:`.ColumnOperators`.
"""
+
__slots__ = ()
def __and__(self, other):
@@ -105,8 +123,8 @@ class Operators(object):
return self.operate(inv)
def op(
- self, opstring, precedence=0, is_comparison=False,
- return_type=None):
+ self, opstring, precedence=0, is_comparison=False, return_type=None
+ ):
"""produce a generic operator function.
e.g.::
@@ -168,6 +186,7 @@ class Operators(object):
def against(other):
return operator(self, other)
+
return against
def bool_op(self, opstring, precedence=0):
@@ -247,12 +266,18 @@ class custom_op(object):
:meth:`.Operators.bool_op`
"""
- __name__ = 'custom_op'
+
+ __name__ = "custom_op"
def __init__(
- self, opstring, precedence=0, is_comparison=False,
- return_type=None, natural_self_precedent=False,
- eager_grouping=False):
+ self,
+ opstring,
+ precedence=0,
+ is_comparison=False,
+ return_type=None,
+ natural_self_precedent=False,
+ eager_grouping=False,
+ ):
self.opstring = opstring
self.precedence = precedence
self.is_comparison = is_comparison
@@ -263,8 +288,7 @@ class custom_op(object):
)
def __eq__(self, other):
- return isinstance(other, custom_op) and \
- other.opstring == self.opstring
+ return isinstance(other, custom_op) and other.opstring == self.opstring
def __hash__(self):
return id(self)
@@ -1138,6 +1162,7 @@ class ColumnOperators(Operators):
"""
return self.reverse_operate(truediv, other)
+
_commutative = {eq, ne, add, mul}
_comparison = {eq, ne, lt, gt, ge, le}
@@ -1261,20 +1286,18 @@ def _escaped_like_impl(fn, other, escape, autoescape):
if autoescape:
if autoescape is not True:
util.warn(
- "The autoescape parameter is now a simple boolean True/False")
+ "The autoescape parameter is now a simple boolean True/False"
+ )
if escape is None:
- escape = '/'
+ escape = "/"
if not isinstance(other, util.compat.string_types):
raise TypeError("String value expected when autoescape=True")
- if escape not in ('%', '_'):
+ if escape not in ("%", "_"):
other = other.replace(escape, escape + escape)
- other = (
- other.replace('%', escape + '%').
- replace('_', escape + '_')
- )
+ other = other.replace("%", escape + "%").replace("_", escape + "_")
return fn(other, escape=escape)
@@ -1362,8 +1385,7 @@ def json_path_getitem_op(a, b):
def is_comparison(op):
- return op in _comparison or \
- isinstance(op, custom_op) and op.is_comparison
+ return op in _comparison or isinstance(op, custom_op) and op.is_comparison
def is_commutative(op):
@@ -1371,13 +1393,16 @@ def is_commutative(op):
def is_ordering_modifier(op):
- return op in (asc_op, desc_op,
- nullsfirst_op, nullslast_op)
+ return op in (asc_op, desc_op, nullsfirst_op, nullslast_op)
def is_natural_self_precedent(op):
- return op in _natural_self_precedent or \
- isinstance(op, custom_op) and op.natural_self_precedent
+ return (
+ op in _natural_self_precedent
+ or isinstance(op, custom_op)
+ and op.natural_self_precedent
+ )
+
_booleans = (inv, istrue, isfalse, and_, or_)
@@ -1385,12 +1410,8 @@ _booleans = (inv, istrue, isfalse, and_, or_)
def is_boolean(op):
return is_comparison(op) or op in _booleans
-_mirror = {
- gt: lt,
- ge: le,
- lt: gt,
- le: ge
-}
+
+_mirror = {gt: lt, ge: le, lt: gt, le: ge}
def mirror(op):
@@ -1404,17 +1425,18 @@ def mirror(op):
_associative = _commutative.union([concat_op, and_, or_]).difference([eq, ne])
-_natural_self_precedent = _associative.union([
- getitem, json_getitem_op, json_path_getitem_op])
+_natural_self_precedent = _associative.union(
+ [getitem, json_getitem_op, json_path_getitem_op]
+)
"""Operators where if we have (a op b) op c, we don't want to
parenthesize (a op b).
"""
-_asbool = util.symbol('_asbool', canonical=-10)
-_smallest = util.symbol('_smallest', canonical=-100)
-_largest = util.symbol('_largest', canonical=100)
+_asbool = util.symbol("_asbool", canonical=-10)
+_smallest = util.symbol("_smallest", canonical=-100)
+_largest = util.symbol("_largest", canonical=100)
_PRECEDENCE = {
from_: 15,
@@ -1424,7 +1446,6 @@ _PRECEDENCE = {
getitem: 15,
json_getitem_op: 15,
json_path_getitem_op: 15,
-
mul: 8,
truediv: 8,
div: 8,
@@ -1432,22 +1453,17 @@ _PRECEDENCE = {
neg: 8,
add: 7,
sub: 7,
-
concat_op: 6,
-
match_op: 5,
notmatch_op: 5,
-
ilike_op: 5,
notilike_op: 5,
like_op: 5,
notlike_op: 5,
in_op: 5,
notin_op: 5,
-
is_: 5,
isnot: 5,
-
eq: 5,
ne: 5,
is_distinct_from: 5,
@@ -1458,7 +1474,6 @@ _PRECEDENCE = {
lt: 5,
ge: 5,
le: 5,
-
between_op: 5,
notbetween_op: 5,
distinct_op: 5,
@@ -1468,17 +1483,14 @@ _PRECEDENCE = {
and_: 3,
or_: 2,
comma_op: -1,
-
desc_op: 3,
asc_op: 3,
collate: 4,
-
as_: -1,
exists: 0,
-
_asbool: -10,
_smallest: _smallest,
- _largest: _largest
+ _largest: _largest,
}
@@ -1486,7 +1498,6 @@ def is_precedent(operator, against):
if operator is against and is_natural_self_precedent(operator):
return False
else:
- return (_PRECEDENCE.get(operator,
- getattr(operator, 'precedence', _smallest)) <=
- _PRECEDENCE.get(against,
- getattr(against, 'precedence', _largest)))
+ return _PRECEDENCE.get(
+ operator, getattr(operator, "precedence", _smallest)
+ ) <= _PRECEDENCE.get(against, getattr(against, "precedence", _largest))
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 3e9aa174a..d6c3f5000 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -36,25 +36,31 @@ import operator
from . import visitors
from . import type_api
from .base import _bind_or_error, ColumnCollection
-from .elements import ClauseElement, ColumnClause, \
- _as_truncated, TextClause, _literal_as_text,\
- ColumnElement, quoted_name
+from .elements import (
+ ClauseElement,
+ ColumnClause,
+ _as_truncated,
+ TextClause,
+ _literal_as_text,
+ ColumnElement,
+ quoted_name,
+)
from .selectable import TableClause
import collections
import sqlalchemy
from . import ddl
-RETAIN_SCHEMA = util.symbol('retain_schema')
+RETAIN_SCHEMA = util.symbol("retain_schema")
BLANK_SCHEMA = util.symbol(
- 'blank_schema',
+ "blank_schema",
"""Symbol indicating that a :class:`.Table` or :class:`.Sequence`
should have 'None' for its schema, even if the parent
:class:`.MetaData` has specified a schema.
.. versionadded:: 1.0.14
- """
+ """,
)
@@ -69,11 +75,15 @@ def _get_table_key(name, schema):
# break an import cycle
def _copy_expression(expression, source_table, target_table):
def replace(col):
- if isinstance(col, Column) and \
- col.table is source_table and col.key in source_table.c:
+ if (
+ isinstance(col, Column)
+ and col.table is source_table
+ and col.key in source_table.c
+ ):
return target_table.c[col.key]
else:
return None
+
return visitors.replacement_traverse(expression, {}, replace)
@@ -81,7 +91,7 @@ def _copy_expression(expression, source_table, target_table):
class SchemaItem(SchemaEventTarget, visitors.Visitable):
"""Base class for items that define a database schema."""
- __visit_name__ = 'schema_item'
+ __visit_name__ = "schema_item"
def _init_items(self, *args):
"""Initialize the list of child items for this SchemaItem."""
@@ -95,10 +105,10 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
return []
def __repr__(self):
- return util.generic_repr(self, omit_kwarg=['info'])
+ return util.generic_repr(self, omit_kwarg=["info"])
@property
- @util.deprecated('0.9', 'Use ``<obj>.name.quote``')
+ @util.deprecated("0.9", "Use ``<obj>.name.quote``")
def quote(self):
"""Return the value of the ``quote`` flag passed
to this schema object, for those schema items which
@@ -121,7 +131,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
return {}
def _schema_item_copy(self, schema_item):
- if 'info' in self.__dict__:
+ if "info" in self.__dict__:
schema_item.info = self.info.copy()
schema_item.dispatch._update(self.dispatch)
return schema_item
@@ -396,7 +406,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"""
- __visit_name__ = 'table'
+ __visit_name__ = "table"
def __new__(cls, *args, **kw):
if not args:
@@ -408,26 +418,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
except IndexError:
raise TypeError("Table() takes at least two arguments")
- schema = kw.get('schema', None)
+ schema = kw.get("schema", None)
if schema is None:
schema = metadata.schema
elif schema is BLANK_SCHEMA:
schema = None
- keep_existing = kw.pop('keep_existing', False)
- extend_existing = kw.pop('extend_existing', False)
- if 'useexisting' in kw:
+ keep_existing = kw.pop("keep_existing", False)
+ extend_existing = kw.pop("extend_existing", False)
+ if "useexisting" in kw:
msg = "useexisting is deprecated. Use extend_existing."
util.warn_deprecated(msg)
if extend_existing:
msg = "useexisting is synonymous with extend_existing."
raise exc.ArgumentError(msg)
- extend_existing = kw.pop('useexisting', False)
+ extend_existing = kw.pop("useexisting", False)
if keep_existing and extend_existing:
msg = "keep_existing and extend_existing are mutually exclusive."
raise exc.ArgumentError(msg)
- mustexist = kw.pop('mustexist', False)
+ mustexist = kw.pop("mustexist", False)
key = _get_table_key(name, schema)
if key in metadata.tables:
if not keep_existing and not extend_existing and bool(args):
@@ -436,15 +446,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"instance. Specify 'extend_existing=True' "
"to redefine "
"options and columns on an "
- "existing Table object." % key)
+ "existing Table object." % key
+ )
table = metadata.tables[key]
if extend_existing:
table._init_existing(*args, **kw)
return table
else:
if mustexist:
- raise exc.InvalidRequestError(
- "Table '%s' not defined" % (key))
+ raise exc.InvalidRequestError("Table '%s' not defined" % (key))
table = object.__new__(cls)
table.dispatch.before_parent_attach(table, metadata)
metadata._add_table(name, schema, table)
@@ -457,7 +467,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
metadata._remove_table(name, schema)
@property
- @util.deprecated('0.9', 'Use ``table.schema.quote``')
+ @util.deprecated("0.9", "Use ``table.schema.quote``")
def quote_schema(self):
"""Return the value of the ``quote_schema`` flag passed
to this :class:`.Table`.
@@ -478,23 +488,25 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
def _init(self, name, metadata, *args, **kwargs):
super(Table, self).__init__(
- quoted_name(name, kwargs.pop('quote', None)))
+ quoted_name(name, kwargs.pop("quote", None))
+ )
self.metadata = metadata
- self.schema = kwargs.pop('schema', None)
+ self.schema = kwargs.pop("schema", None)
if self.schema is None:
self.schema = metadata.schema
elif self.schema is BLANK_SCHEMA:
self.schema = None
else:
- quote_schema = kwargs.pop('quote_schema', None)
+ quote_schema = kwargs.pop("quote_schema", None)
self.schema = quoted_name(self.schema, quote_schema)
self.indexes = set()
self.constraints = set()
self._columns = ColumnCollection()
- PrimaryKeyConstraint(_implicit_generated=True).\
- _set_parent_with_dispatch(self)
+ PrimaryKeyConstraint(
+ _implicit_generated=True
+ )._set_parent_with_dispatch(self)
self.foreign_keys = set()
self._extra_dependencies = set()
if self.schema is not None:
@@ -502,26 +514,26 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
else:
self.fullname = self.name
- autoload_with = kwargs.pop('autoload_with', None)
- autoload = kwargs.pop('autoload', autoload_with is not None)
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
# this argument is only used with _init_existing()
- kwargs.pop('autoload_replace', True)
+ kwargs.pop("autoload_replace", True)
_extend_on = kwargs.pop("_extend_on", None)
- include_columns = kwargs.pop('include_columns', None)
+ include_columns = kwargs.pop("include_columns", None)
- self.implicit_returning = kwargs.pop('implicit_returning', True)
+ self.implicit_returning = kwargs.pop("implicit_returning", True)
- self.comment = kwargs.pop('comment', None)
+ self.comment = kwargs.pop("comment", None)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
- if 'listeners' in kwargs:
- listeners = kwargs.pop('listeners')
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
+ if "listeners" in kwargs:
+ listeners = kwargs.pop("listeners")
for evt, fn in listeners:
event.listen(self, evt, fn)
- self._prefixes = kwargs.pop('prefixes', [])
+ self._prefixes = kwargs.pop("prefixes", [])
self._extra_kwargs(**kwargs)
@@ -530,21 +542,29 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
# circular foreign keys
if autoload:
self._autoload(
- metadata, autoload_with,
- include_columns, _extend_on=_extend_on)
+ metadata, autoload_with, include_columns, _extend_on=_extend_on
+ )
# initialize all the column, etc. objects. done after reflection to
# allow user-overrides
self._init_items(*args)
- def _autoload(self, metadata, autoload_with, include_columns,
- exclude_columns=(), _extend_on=None):
+ def _autoload(
+ self,
+ metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns=(),
+ _extend_on=None,
+ ):
if autoload_with:
autoload_with.run_callable(
autoload_with.dialect.reflecttable,
- self, include_columns, exclude_columns,
- _extend_on=_extend_on
+ self,
+ include_columns,
+ exclude_columns,
+ _extend_on=_extend_on,
)
else:
bind = _bind_or_error(
@@ -553,11 +573,14 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"Pass an engine to the Table via "
"autoload_with=<someengine>, "
"or associate the MetaData with an engine via "
- "metadata.bind=<someengine>")
+ "metadata.bind=<someengine>",
+ )
bind.run_callable(
bind.dialect.reflecttable,
- self, include_columns, exclude_columns,
- _extend_on=_extend_on
+ self,
+ include_columns,
+ exclude_columns,
+ _extend_on=_extend_on,
)
@property
@@ -582,34 +605,36 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
return set(fkc.constraint for fkc in self.foreign_keys)
def _init_existing(self, *args, **kwargs):
- autoload_with = kwargs.pop('autoload_with', None)
- autoload = kwargs.pop('autoload', autoload_with is not None)
- autoload_replace = kwargs.pop('autoload_replace', True)
- schema = kwargs.pop('schema', None)
- _extend_on = kwargs.pop('_extend_on', None)
+ autoload_with = kwargs.pop("autoload_with", None)
+ autoload = kwargs.pop("autoload", autoload_with is not None)
+ autoload_replace = kwargs.pop("autoload_replace", True)
+ schema = kwargs.pop("schema", None)
+ _extend_on = kwargs.pop("_extend_on", None)
if schema and schema != self.schema:
raise exc.ArgumentError(
"Can't change schema of existing table from '%s' to '%s'",
- (self.schema, schema))
+ (self.schema, schema),
+ )
- include_columns = kwargs.pop('include_columns', None)
+ include_columns = kwargs.pop("include_columns", None)
if include_columns is not None:
for c in self.c:
if c.name not in include_columns:
self._columns.remove(c)
- for key in ('quote', 'quote_schema'):
+ for key in ("quote", "quote_schema"):
if key in kwargs:
raise exc.ArgumentError(
- "Can't redefine 'quote' or 'quote_schema' arguments")
+ "Can't redefine 'quote' or 'quote_schema' arguments"
+ )
- if 'comment' in kwargs:
- self.comment = kwargs.pop('comment', None)
+ if "comment" in kwargs:
+ self.comment = kwargs.pop("comment", None)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
if autoload:
if not autoload_replace:
@@ -620,8 +645,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
else:
exclude_columns = ()
self._autoload(
- self.metadata, autoload_with,
- include_columns, exclude_columns, _extend_on=_extend_on)
+ self.metadata,
+ autoload_with,
+ include_columns,
+ exclude_columns,
+ _extend_on=_extend_on,
+ )
self._extra_kwargs(**kwargs)
self._init_items(*args)
@@ -653,10 +682,12 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
return _get_table_key(self.name, self.schema)
def __repr__(self):
- return "Table(%s)" % ', '.join(
- [repr(self.name)] + [repr(self.metadata)] +
- [repr(x) for x in self.columns] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in ['schema']])
+ return "Table(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.metadata)]
+ + [repr(x) for x in self.columns]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in ["schema"]]
+ )
def __str__(self):
return _get_table_key(self.description, self.schema)
@@ -735,17 +766,19 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
def adapt_listener(target, connection, **kw):
listener(event_name, target, connection)
- event.listen(self, "" + event_name.replace('-', '_'), adapt_listener)
+ event.listen(self, "" + event_name.replace("-", "_"), adapt_listener)
def _set_parent(self, metadata):
metadata._add_table(self.name, self.schema, self)
self.metadata = metadata
- def get_children(self, column_collections=True,
- schema_visitor=False, **kw):
+ def get_children(
+ self, column_collections=True, schema_visitor=False, **kw
+ ):
if not schema_visitor:
return TableClause.get_children(
- self, column_collections=column_collections, **kw)
+ self, column_collections=column_collections, **kw
+ )
else:
if column_collections:
return list(self.columns)
@@ -758,8 +791,9 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
if bind is None:
bind = _bind_or_error(self)
- return bind.run_callable(bind.dialect.has_table,
- self.name, schema=self.schema)
+ return bind.run_callable(
+ bind.dialect.has_table, self.name, schema=self.schema
+ )
def create(self, bind=None, checkfirst=False):
"""Issue a ``CREATE`` statement for this
@@ -774,9 +808,7 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaGenerator,
- self,
- checkfirst=checkfirst)
+ bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=False):
"""Issue a ``DROP`` statement for this
@@ -790,12 +822,15 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
"""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaDropper,
- self,
- checkfirst=checkfirst)
-
- def tometadata(self, metadata, schema=RETAIN_SCHEMA,
- referred_schema_fn=None, name=None):
+ bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
+
+ def tometadata(
+ self,
+ metadata,
+ schema=RETAIN_SCHEMA,
+ referred_schema_fn=None,
+ name=None,
+ ):
"""Return a copy of this :class:`.Table` associated with a different
:class:`.MetaData`.
@@ -868,29 +903,37 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
schema = metadata.schema
key = _get_table_key(name, schema)
if key in metadata.tables:
- util.warn("Table '%s' already exists within the given "
- "MetaData - not copying." % self.description)
+ util.warn(
+ "Table '%s' already exists within the given "
+ "MetaData - not copying." % self.description
+ )
return metadata.tables[key]
args = []
for c in self.columns:
args.append(c.copy(schema=schema))
table = Table(
- name, metadata, schema=schema,
+ name,
+ metadata,
+ schema=schema,
comment=self.comment,
- *args, **self.kwargs
+ *args,
+ **self.kwargs
)
for c in self.constraints:
if isinstance(c, ForeignKeyConstraint):
referred_schema = c._referred_schema
if referred_schema_fn:
fk_constraint_schema = referred_schema_fn(
- self, schema, c, referred_schema)
+ self, schema, c, referred_schema
+ )
else:
fk_constraint_schema = (
- schema if referred_schema == self.schema else None)
+ schema if referred_schema == self.schema else None
+ )
table.append_constraint(
- c.copy(schema=fk_constraint_schema, target_table=table))
+ c.copy(schema=fk_constraint_schema, target_table=table)
+ )
elif not c._type_bound:
# skip unique constraints that would be generated
# by the 'unique' flag on Column
@@ -898,25 +941,30 @@ class Table(DialectKWArgs, SchemaItem, TableClause):
continue
table.append_constraint(
- c.copy(schema=schema, target_table=table))
+ c.copy(schema=schema, target_table=table)
+ )
for index in self.indexes:
# skip indexes that would be generated
# by the 'index' flag on Column
if index._column_flag:
continue
- Index(index.name,
- unique=index.unique,
- *[_copy_expression(expr, self, table)
- for expr in index.expressions],
- _table=table,
- **index.kwargs)
+ Index(
+ index.name,
+ unique=index.unique,
+ *[
+ _copy_expression(expr, self, table)
+ for expr in index.expressions
+ ],
+ _table=table,
+ **index.kwargs
+ )
return self._schema_item_copy(table)
class Column(DialectKWArgs, SchemaItem, ColumnClause):
"""Represents a column in a database table."""
- __visit_name__ = 'column'
+ __visit_name__ = "column"
def __init__(self, *args, **kwargs):
r"""
@@ -1192,14 +1240,15 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"""
- name = kwargs.pop('name', None)
- type_ = kwargs.pop('type_', None)
+ name = kwargs.pop("name", None)
+ type_ = kwargs.pop("type_", None)
args = list(args)
if args:
if isinstance(args[0], util.string_types):
if name is not None:
raise exc.ArgumentError(
- "May not pass name positionally and as a keyword.")
+ "May not pass name positionally and as a keyword."
+ )
name = args.pop(0)
if args:
coltype = args[0]
@@ -1207,40 +1256,42 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if hasattr(coltype, "_sqla_type"):
if type_ is not None:
raise exc.ArgumentError(
- "May not pass type_ positionally and as a keyword.")
+ "May not pass type_ positionally and as a keyword."
+ )
type_ = args.pop(0)
if name is not None:
- name = quoted_name(name, kwargs.pop('quote', None))
+ name = quoted_name(name, kwargs.pop("quote", None))
elif "quote" in kwargs:
- raise exc.ArgumentError("Explicit 'name' is required when "
- "sending 'quote' argument")
+ raise exc.ArgumentError(
+ "Explicit 'name' is required when " "sending 'quote' argument"
+ )
super(Column, self).__init__(name, type_)
- self.key = kwargs.pop('key', name)
- self.primary_key = kwargs.pop('primary_key', False)
- self.nullable = kwargs.pop('nullable', not self.primary_key)
- self.default = kwargs.pop('default', None)
- self.server_default = kwargs.pop('server_default', None)
- self.server_onupdate = kwargs.pop('server_onupdate', None)
+ self.key = kwargs.pop("key", name)
+ self.primary_key = kwargs.pop("primary_key", False)
+ self.nullable = kwargs.pop("nullable", not self.primary_key)
+ self.default = kwargs.pop("default", None)
+ self.server_default = kwargs.pop("server_default", None)
+ self.server_onupdate = kwargs.pop("server_onupdate", None)
# these default to None because .index and .unique is *not*
# an informational flag about Column - there can still be an
# Index or UniqueConstraint referring to this Column.
- self.index = kwargs.pop('index', None)
- self.unique = kwargs.pop('unique', None)
+ self.index = kwargs.pop("index", None)
+ self.unique = kwargs.pop("unique", None)
- self.system = kwargs.pop('system', False)
- self.doc = kwargs.pop('doc', None)
- self.onupdate = kwargs.pop('onupdate', None)
- self.autoincrement = kwargs.pop('autoincrement', "auto")
+ self.system = kwargs.pop("system", False)
+ self.doc = kwargs.pop("doc", None)
+ self.onupdate = kwargs.pop("onupdate", None)
+ self.autoincrement = kwargs.pop("autoincrement", "auto")
self.constraints = set()
self.foreign_keys = set()
- self.comment = kwargs.pop('comment', None)
+ self.comment = kwargs.pop("comment", None)
# check if this Column is proxying another column
- if '_proxies' in kwargs:
- self._proxies = kwargs.pop('_proxies')
+ if "_proxies" in kwargs:
+ self._proxies = kwargs.pop("_proxies")
# otherwise, add DDL-related events
elif isinstance(self.type, SchemaEventTarget):
self.type._set_parent_with_dispatch(self)
@@ -1249,14 +1300,13 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if isinstance(self.default, (ColumnDefault, Sequence)):
args.append(self.default)
else:
- if getattr(self.type, '_warn_on_bytestring', False):
+ if getattr(self.type, "_warn_on_bytestring", False):
if isinstance(self.default, util.binary_type):
util.warn(
"Unicode column '%s' has non-unicode "
- "default value %r specified." % (
- self.key,
- self.default
- ))
+ "default value %r specified."
+ % (self.key, self.default)
+ )
args.append(ColumnDefault(self.default))
if self.server_default is not None:
@@ -1275,30 +1325,31 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if isinstance(self.server_onupdate, FetchedValue):
args.append(self.server_onupdate._as_for_update(True))
else:
- args.append(DefaultClause(self.server_onupdate,
- for_update=True))
+ args.append(
+ DefaultClause(self.server_onupdate, for_update=True)
+ )
self._init_items(*args)
util.set_creation_order(self)
- if 'info' in kwargs:
- self.info = kwargs.pop('info')
+ if "info" in kwargs:
+ self.info = kwargs.pop("info")
self._extra_kwargs(**kwargs)
def _extra_kwargs(self, **kwargs):
self._validate_dialect_kwargs(kwargs)
-# @property
-# def quote(self):
-# return getattr(self.name, "quote", None)
+ # @property
+ # def quote(self):
+ # return getattr(self.name, "quote", None)
def __str__(self):
if self.name is None:
return "(no name)"
elif self.table is not None:
if self.table.named_with_column:
- return (self.table.description + "." + self.description)
+ return self.table.description + "." + self.description
else:
return self.description
else:
@@ -1320,40 +1371,47 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
def __repr__(self):
kwarg = []
if self.key != self.name:
- kwarg.append('key')
+ kwarg.append("key")
if self.primary_key:
- kwarg.append('primary_key')
+ kwarg.append("primary_key")
if not self.nullable:
- kwarg.append('nullable')
+ kwarg.append("nullable")
if self.onupdate:
- kwarg.append('onupdate')
+ kwarg.append("onupdate")
if self.default:
- kwarg.append('default')
+ kwarg.append("default")
if self.server_default:
- kwarg.append('server_default')
- return "Column(%s)" % ', '.join(
- [repr(self.name)] + [repr(self.type)] +
- [repr(x) for x in self.foreign_keys if x is not None] +
- [repr(x) for x in self.constraints] +
- [(self.table is not None and "table=<%s>" %
- self.table.description or "table=None")] +
- ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg])
+ kwarg.append("server_default")
+ return "Column(%s)" % ", ".join(
+ [repr(self.name)]
+ + [repr(self.type)]
+ + [repr(x) for x in self.foreign_keys if x is not None]
+ + [repr(x) for x in self.constraints]
+ + [
+ (
+ self.table is not None
+ and "table=<%s>" % self.table.description
+ or "table=None"
+ )
+ ]
+ + ["%s=%s" % (k, repr(getattr(self, k))) for k in kwarg]
+ )
def _set_parent(self, table):
if not self.name:
raise exc.ArgumentError(
"Column must be constructed with a non-blank name or "
- "assign a non-blank .name before adding to a Table.")
+ "assign a non-blank .name before adding to a Table."
+ )
if self.key is None:
self.key = self.name
- existing = getattr(self, 'table', None)
+ existing = getattr(self, "table", None)
if existing is not None and existing is not table:
raise exc.ArgumentError(
- "Column object '%s' already assigned to Table '%s'" % (
- self.key,
- existing.description
- ))
+ "Column object '%s' already assigned to Table '%s'"
+ % (self.key, existing.description)
+ )
if self.key in table._columns:
col = table._columns.get(self.key)
@@ -1373,8 +1431,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
elif self.key in table.primary_key:
raise exc.ArgumentError(
"Trying to redefine primary-key column '%s' as a "
- "non-primary-key column on table '%s'" % (
- self.key, table.fullname))
+ "non-primary-key column on table '%s'"
+ % (self.key, table.fullname)
+ )
self.table = table
@@ -1383,7 +1442,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
raise exc.ArgumentError(
"The 'index' keyword argument on Column is boolean only. "
"To create indexes with a specific name, create an "
- "explicit Index object external to the Table.")
+ "explicit Index object external to the Table."
+ )
Index(None, self, unique=bool(self.unique), _column_flag=True)
elif self.unique:
if isinstance(self.unique, util.string_types):
@@ -1392,9 +1452,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"only. To create unique constraints or indexes with a "
"specific name, append an explicit UniqueConstraint to "
"the Table's list of elements, or create an explicit "
- "Index object external to the Table.")
+ "Index object external to the Table."
+ )
table.append_constraint(
- UniqueConstraint(self.key, _column_flag=True))
+ UniqueConstraint(self.key, _column_flag=True)
+ )
self._setup_on_memoized_fks(lambda fk: fk._set_remote_table(table))
@@ -1413,7 +1475,7 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
if self.table is not None:
fn(self, self.table)
else:
- event.listen(self, 'after_parent_attach', fn)
+ event.listen(self, "after_parent_attach", fn)
def copy(self, **kw):
"""Create a copy of this ``Column``, unitialized.
@@ -1423,9 +1485,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"""
# Constraint objects plus non-constraint-bound ForeignKey objects
- args = \
- [c.copy(**kw) for c in self.constraints if not c._type_bound] + \
- [c.copy(**kw) for c in self.foreign_keys if not c.constraint]
+ args = [
+ c.copy(**kw) for c in self.constraints if not c._type_bound
+ ] + [c.copy(**kw) for c in self.foreign_keys if not c.constraint]
type_ = self.type
if isinstance(type_, SchemaEventTarget):
@@ -1452,8 +1514,9 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
)
return self._schema_item_copy(c)
- def _make_proxy(self, selectable, name=None, key=None,
- name_is_truncatable=False, **kw):
+ def _make_proxy(
+ self, selectable, name=None, key=None, name_is_truncatable=False, **kw
+ ):
"""Create a *proxy* for this column.
This is a copy of this ``Column`` referenced by a different parent
@@ -1462,22 +1525,28 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
information is not transferred.
"""
- fk = [ForeignKey(f.column, _constraint=f.constraint)
- for f in self.foreign_keys]
+ fk = [
+ ForeignKey(f.column, _constraint=f.constraint)
+ for f in self.foreign_keys
+ ]
if name is None and self.name is None:
raise exc.InvalidRequestError(
"Cannot initialize a sub-selectable"
" with this Column object until its 'name' has "
- "been assigned.")
+ "been assigned."
+ )
try:
c = self._constructor(
- _as_truncated(name or self.name) if
- name_is_truncatable else (name or self.name),
+ _as_truncated(name or self.name)
+ if name_is_truncatable
+ else (name or self.name),
self.type,
key=key if key else name if name else self.key,
primary_key=self.primary_key,
nullable=self.nullable,
- _proxies=[self], *fk)
+ _proxies=[self],
+ *fk
+ )
except TypeError:
util.raise_from_cause(
TypeError(
@@ -1485,7 +1554,8 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
"Ensure the class includes a _constructor() "
"attribute or method which accepts the "
"standard Column constructor arguments, or "
- "references the Column class itself." % self.__class__)
+ "references the Column class itself." % self.__class__
+ )
)
c.table = selectable
@@ -1499,9 +1569,11 @@ class Column(DialectKWArgs, SchemaItem, ColumnClause):
def get_children(self, schema_visitor=False, **kwargs):
if schema_visitor:
- return [x for x in (self.default, self.onupdate)
- if x is not None] + \
- list(self.foreign_keys) + list(self.constraints)
+ return (
+ [x for x in (self.default, self.onupdate) if x is not None]
+ + list(self.foreign_keys)
+ + list(self.constraints)
+ )
else:
return ColumnClause.get_children(self, **kwargs)
@@ -1543,13 +1615,23 @@ class ForeignKey(DialectKWArgs, SchemaItem):
"""
- __visit_name__ = 'foreign_key'
-
- def __init__(self, column, _constraint=None, use_alter=False, name=None,
- onupdate=None, ondelete=None, deferrable=None,
- initially=None, link_to_name=False, match=None,
- info=None,
- **dialect_kw):
+ __visit_name__ = "foreign_key"
+
+ def __init__(
+ self,
+ column,
+ _constraint=None,
+ use_alter=False,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ link_to_name=False,
+ match=None,
+ info=None,
+ **dialect_kw
+ ):
r"""
Construct a column-level FOREIGN KEY.
@@ -1626,7 +1708,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if isinstance(self._colspec, util.string_types):
self._table_column = None
else:
- if hasattr(self._colspec, '__clause_element__'):
+ if hasattr(self._colspec, "__clause_element__"):
self._table_column = self._colspec.__clause_element__()
else:
self._table_column = self._colspec
@@ -1634,9 +1716,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if not isinstance(self._table_column, ColumnClause):
raise exc.ArgumentError(
"String, Column, or Column-bound argument "
- "expected, got %r" % self._table_column)
+ "expected, got %r" % self._table_column
+ )
elif not isinstance(
- self._table_column.table, (util.NoneType, TableClause)):
+ self._table_column.table, (util.NoneType, TableClause)
+ ):
raise exc.ArgumentError(
"ForeignKey received Column not bound "
"to a Table, got: %r" % self._table_column.table
@@ -1715,7 +1799,9 @@ class ForeignKey(DialectKWArgs, SchemaItem):
return "%s.%s" % (table_name, colname)
elif self._table_column is not None:
return "%s.%s" % (
- self._table_column.table.fullname, self._table_column.key)
+ self._table_column.table.fullname,
+ self._table_column.key,
+ )
else:
return self._colspec
@@ -1756,12 +1842,12 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def _column_tokens(self):
"""parse a string-based _colspec into its component parts."""
- m = self._get_colspec().split('.')
+ m = self._get_colspec().split(".")
if m is None:
raise exc.ArgumentError(
- "Invalid foreign key column specification: %s" %
- self._colspec)
- if (len(m) == 1):
+ "Invalid foreign key column specification: %s" % self._colspec
+ )
+ if len(m) == 1:
tname = m.pop()
colname = None
else:
@@ -1777,8 +1863,8 @@ class ForeignKey(DialectKWArgs, SchemaItem):
# indirectly related -- Ticket #594. This assumes that '.'
# will never appear *within* any component of the FK.
- if (len(m) > 0):
- schema = '.'.join(m)
+ if len(m) > 0:
+ schema = ".".join(m)
else:
schema = None
return schema, tname, colname
@@ -1787,12 +1873,14 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if self.parent is None:
raise exc.InvalidRequestError(
"this ForeignKey object does not yet have a "
- "parent Column associated with it.")
+ "parent Column associated with it."
+ )
elif self.parent.table is None:
raise exc.InvalidRequestError(
"this ForeignKey's parent column is not yet associated "
- "with a Table.")
+ "with a Table."
+ )
parenttable = self.parent.table
@@ -1817,7 +1905,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
return parenttable, tablekey, colname
def _link_to_col_by_colstring(self, parenttable, table, colname):
- if not hasattr(self.constraint, '_referred_table'):
+ if not hasattr(self.constraint, "_referred_table"):
self.constraint._referred_table = table
else:
assert self.constraint._referred_table is table
@@ -1843,9 +1931,11 @@ class ForeignKey(DialectKWArgs, SchemaItem):
raise exc.NoReferencedColumnError(
"Could not initialize target column "
"for ForeignKey '%s' on table '%s': "
- "table '%s' has no column named '%s'" %
- (self._colspec, parenttable.name, table.name, key),
- table.name, key)
+ "table '%s' has no column named '%s'"
+ % (self._colspec, parenttable.name, table.name, key),
+ table.name,
+ key,
+ )
self._set_target_column(_column)
@@ -1861,6 +1951,7 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def set_type(fk):
if fk.parent.type._isnull:
fk.parent.type = column.type
+
self.parent._setup_on_memoized_fks(set_type)
self.column = column
@@ -1888,21 +1979,25 @@ class ForeignKey(DialectKWArgs, SchemaItem):
raise exc.NoReferencedTableError(
"Foreign key associated with column '%s' could not find "
"table '%s' with which to generate a "
- "foreign key to target column '%s'" %
- (self.parent, tablekey, colname),
- tablekey)
+ "foreign key to target column '%s'"
+ % (self.parent, tablekey, colname),
+ tablekey,
+ )
elif parenttable.key not in parenttable.metadata:
raise exc.InvalidRequestError(
"Table %s is no longer associated with its "
- "parent MetaData" % parenttable)
+ "parent MetaData" % parenttable
+ )
else:
raise exc.NoReferencedColumnError(
"Could not initialize target column for "
"ForeignKey '%s' on table '%s': "
- "table '%s' has no column named '%s'" % (
- self._colspec, parenttable.name, tablekey, colname),
- tablekey, colname)
- elif hasattr(self._colspec, '__clause_element__'):
+ "table '%s' has no column named '%s'"
+ % (self._colspec, parenttable.name, tablekey, colname),
+ tablekey,
+ colname,
+ )
+ elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
return _column
else:
@@ -1912,7 +2007,8 @@ class ForeignKey(DialectKWArgs, SchemaItem):
def _set_parent(self, column):
if self.parent is not None and self.parent is not column:
raise exc.InvalidRequestError(
- "This ForeignKey already has a parent !")
+ "This ForeignKey already has a parent !"
+ )
self.parent = column
self.parent.foreign_keys.add(self)
self.parent._on_table_attach(self._set_table)
@@ -1935,9 +2031,14 @@ class ForeignKey(DialectKWArgs, SchemaItem):
# on the hosting Table when attached to the Table.
if self.constraint is None and isinstance(table, Table):
self.constraint = ForeignKeyConstraint(
- [], [], use_alter=self.use_alter, name=self.name,
- onupdate=self.onupdate, ondelete=self.ondelete,
- deferrable=self.deferrable, initially=self.initially,
+ [],
+ [],
+ use_alter=self.use_alter,
+ name=self.name,
+ onupdate=self.onupdate,
+ ondelete=self.ondelete,
+ deferrable=self.deferrable,
+ initially=self.initially,
match=self.match,
**self._unvalidated_dialect_kw
)
@@ -1953,13 +2054,12 @@ class ForeignKey(DialectKWArgs, SchemaItem):
if table_key in parenttable.metadata.tables:
table = parenttable.metadata.tables[table_key]
try:
- self._link_to_col_by_colstring(
- parenttable, table, colname)
+ self._link_to_col_by_colstring(parenttable, table, colname)
except exc.NoReferencedColumnError:
# this is OK, we'll try later
pass
parenttable.metadata._fk_memos[fk_key].append(self)
- elif hasattr(self._colspec, '__clause_element__'):
+ elif hasattr(self._colspec, "__clause_element__"):
_column = self._colspec.__clause_element__()
self._set_target_column(_column)
else:
@@ -1971,7 +2071,8 @@ class _NotAColumnExpr(object):
def _not_a_column_expr(self):
raise exc.InvalidRequestError(
"This %s cannot be used directly "
- "as a column expression." % self.__class__.__name__)
+ "as a column expression." % self.__class__.__name__
+ )
__clause_element__ = self_group = lambda self: self._not_a_column_expr()
_from_objects = property(lambda self: self._not_a_column_expr())
@@ -1980,7 +2081,7 @@ class _NotAColumnExpr(object):
class DefaultGenerator(_NotAColumnExpr, SchemaItem):
"""Base class for column *default* values."""
- __visit_name__ = 'default_generator'
+ __visit_name__ = "default_generator"
is_sequence = False
is_server_default = False
@@ -2007,7 +2108,7 @@ class DefaultGenerator(_NotAColumnExpr, SchemaItem):
@property
def bind(self):
"""Return the connectable associated with this default."""
- if getattr(self, 'column', None) is not None:
+ if getattr(self, "column", None) is not None:
return self.column.table.bind
else:
return None
@@ -2064,7 +2165,8 @@ class ColumnDefault(DefaultGenerator):
super(ColumnDefault, self).__init__(**kwargs)
if isinstance(arg, FetchedValue):
raise exc.ArgumentError(
- "ColumnDefault may not be a server-side default type.")
+ "ColumnDefault may not be a server-side default type."
+ )
if util.callable(arg):
arg = self._maybe_wrap_callable(arg)
self.arg = arg
@@ -2079,9 +2181,11 @@ class ColumnDefault(DefaultGenerator):
@util.memoized_property
def is_scalar(self):
- return not self.is_callable and \
- not self.is_clause_element and \
- not self.is_sequence
+ return (
+ not self.is_callable
+ and not self.is_clause_element
+ and not self.is_sequence
+ )
@util.memoized_property
@util.dependencies("sqlalchemy.sql.sqltypes")
@@ -2114,17 +2218,19 @@ class ColumnDefault(DefaultGenerator):
else:
raise exc.ArgumentError(
"ColumnDefault Python function takes zero or one "
- "positional arguments")
+ "positional arguments"
+ )
def _visit_name(self):
if self.for_update:
return "column_onupdate"
else:
return "column_default"
+
__visit_name__ = property(_visit_name)
def __repr__(self):
- return "ColumnDefault(%r)" % (self.arg, )
+ return "ColumnDefault(%r)" % (self.arg,)
class Sequence(DefaultGenerator):
@@ -2157,15 +2263,29 @@ class Sequence(DefaultGenerator):
"""
- __visit_name__ = 'sequence'
+ __visit_name__ = "sequence"
is_sequence = True
- def __init__(self, name, start=None, increment=None, minvalue=None,
- maxvalue=None, nominvalue=None, nomaxvalue=None, cycle=None,
- schema=None, cache=None, order=None, optional=False,
- quote=None, metadata=None, quote_schema=None,
- for_update=False):
+ def __init__(
+ self,
+ name,
+ start=None,
+ increment=None,
+ minvalue=None,
+ maxvalue=None,
+ nominvalue=None,
+ nomaxvalue=None,
+ cycle=None,
+ schema=None,
+ cache=None,
+ order=None,
+ optional=False,
+ quote=None,
+ metadata=None,
+ quote_schema=None,
+ for_update=False,
+ ):
"""Construct a :class:`.Sequence` object.
:param name: The name of the sequence.
@@ -2353,27 +2473,22 @@ class Sequence(DefaultGenerator):
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaGenerator,
- self,
- checkfirst=checkfirst)
+ bind._run_visitor(ddl.SchemaGenerator, self, checkfirst=checkfirst)
def drop(self, bind=None, checkfirst=True):
"""Drops this sequence from the database."""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaDropper,
- self,
- checkfirst=checkfirst)
+ bind._run_visitor(ddl.SchemaDropper, self, checkfirst=checkfirst)
def _not_a_column_expr(self):
raise exc.InvalidRequestError(
"This %s cannot be used directly "
"as a column expression. Use func.next_value(sequence) "
"to produce a 'next value' function that's usable "
- "as a column element."
- % self.__class__.__name__)
-
+ "as a column element." % self.__class__.__name__
+ )
@inspection._self_inspects
@@ -2396,6 +2511,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
:ref:`triggered_columns`
"""
+
is_server_default = True
reflected = False
has_argument = False
@@ -2412,7 +2528,7 @@ class FetchedValue(_NotAColumnExpr, SchemaEventTarget):
def _clone(self, for_update):
n = self.__class__.__new__(self.__class__)
n.__dict__.update(self.__dict__)
- n.__dict__.pop('column', None)
+ n.__dict__.pop("column", None)
n.for_update = for_update
return n
@@ -2452,16 +2568,15 @@ class DefaultClause(FetchedValue):
has_argument = True
def __init__(self, arg, for_update=False, _reflected=False):
- util.assert_arg_type(arg, (util.string_types[0],
- ClauseElement,
- TextClause), 'arg')
+ util.assert_arg_type(
+ arg, (util.string_types[0], ClauseElement, TextClause), "arg"
+ )
super(DefaultClause, self).__init__(for_update)
self.arg = arg
self.reflected = _reflected
def __repr__(self):
- return "DefaultClause(%r, for_update=%r)" % \
- (self.arg, self.for_update)
+ return "DefaultClause(%r, for_update=%r)" % (self.arg, self.for_update)
class PassiveDefault(DefaultClause):
@@ -2471,10 +2586,13 @@ class PassiveDefault(DefaultClause):
:class:`.PassiveDefault` is deprecated.
Use :class:`.DefaultClause`.
"""
- @util.deprecated("0.6",
- ":class:`.PassiveDefault` is deprecated. "
- "Use :class:`.DefaultClause`.",
- False)
+
+ @util.deprecated(
+ "0.6",
+ ":class:`.PassiveDefault` is deprecated. "
+ "Use :class:`.DefaultClause`.",
+ False,
+ )
def __init__(self, *arg, **kw):
DefaultClause.__init__(self, *arg, **kw)
@@ -2482,11 +2600,18 @@ class PassiveDefault(DefaultClause):
class Constraint(DialectKWArgs, SchemaItem):
"""A table-level SQL constraint."""
- __visit_name__ = 'constraint'
-
- def __init__(self, name=None, deferrable=None, initially=None,
- _create_rule=None, info=None, _type_bound=False,
- **dialect_kw):
+ __visit_name__ = "constraint"
+
+ def __init__(
+ self,
+ name=None,
+ deferrable=None,
+ initially=None,
+ _create_rule=None,
+ info=None,
+ _type_bound=False,
+ **dialect_kw
+ ):
r"""Create a SQL constraint.
:param name:
@@ -2548,7 +2673,8 @@ class Constraint(DialectKWArgs, SchemaItem):
pass
raise exc.InvalidRequestError(
"This constraint is not bound to a table. Did you "
- "mean to call table.append_constraint(constraint) ?")
+ "mean to call table.append_constraint(constraint) ?"
+ )
def _set_parent(self, parent):
self.parent = parent
@@ -2559,7 +2685,7 @@ class Constraint(DialectKWArgs, SchemaItem):
def _to_schema_column(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
if not isinstance(element, Column):
raise exc.ArgumentError("schema.Column object expected")
@@ -2567,9 +2693,9 @@ def _to_schema_column(element):
def _to_schema_column_or_string(element):
- if hasattr(element, '__clause_element__'):
+ if hasattr(element, "__clause_element__"):
element = element.__clause_element__()
- if not isinstance(element, util.string_types + (ColumnElement, )):
+ if not isinstance(element, util.string_types + (ColumnElement,)):
msg = "Element %r is not a string name or column element"
raise exc.ArgumentError(msg % element)
return element
@@ -2588,11 +2714,12 @@ class ColumnCollectionMixin(object):
_allow_multiple_tables = False
def __init__(self, *columns, **kw):
- _autoattach = kw.pop('_autoattach', True)
- self._column_flag = kw.pop('_column_flag', False)
+ _autoattach = kw.pop("_autoattach", True)
+ self._column_flag = kw.pop("_column_flag", False)
self.columns = ColumnCollection()
- self._pending_colargs = [_to_schema_column_or_string(c)
- for c in columns]
+ self._pending_colargs = [
+ _to_schema_column_or_string(c) for c in columns
+ ]
if _autoattach and self._pending_colargs:
self._check_attach()
@@ -2601,7 +2728,7 @@ class ColumnCollectionMixin(object):
for expr in expressions:
strname = None
column = None
- if hasattr(expr, '__clause_element__'):
+ if hasattr(expr, "__clause_element__"):
expr = expr.__clause_element__()
if not isinstance(expr, (ColumnElement, TextClause)):
@@ -2609,21 +2736,16 @@ class ColumnCollectionMixin(object):
strname = expr
else:
cols = []
- visitors.traverse(expr, {}, {'column': cols.append})
+ visitors.traverse(expr, {}, {"column": cols.append})
if cols:
column = cols[0]
add_element = column if column is not None else strname
yield expr, column, strname, add_element
def _check_attach(self, evt=False):
- col_objs = [
- c for c in self._pending_colargs
- if isinstance(c, Column)
- ]
+ col_objs = [c for c in self._pending_colargs if isinstance(c, Column)]
- cols_w_table = [
- c for c in col_objs if isinstance(c.table, Table)
- ]
+ cols_w_table = [c for c in col_objs if isinstance(c.table, Table)]
cols_wo_table = set(col_objs).difference(cols_w_table)
@@ -2636,6 +2758,7 @@ class ColumnCollectionMixin(object):
# columns are specified as strings.
has_string_cols = set(self._pending_colargs).difference(col_objs)
if not has_string_cols:
+
def _col_attached(column, table):
# this isinstance() corresponds with the
# isinstance() above; only want to count Table-bound
@@ -2644,6 +2767,7 @@ class ColumnCollectionMixin(object):
cols_wo_table.discard(column)
if not cols_wo_table:
self._check_attach(evt=True)
+
self._cols_wo_table = cols_wo_table
for col in cols_wo_table:
col._on_table_attach(_col_attached)
@@ -2659,9 +2783,11 @@ class ColumnCollectionMixin(object):
others = [c for c in columns[1:] if c.table is not table]
if others:
raise exc.ArgumentError(
- "Column(s) %s are not part of table '%s'." %
- (", ".join("'%s'" % c for c in others),
- table.description)
+ "Column(s) %s are not part of table '%s'."
+ % (
+ ", ".join("'%s'" % c for c in others),
+ table.description,
+ )
)
def _set_parent(self, table):
@@ -2694,11 +2820,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
arguments are propagated to the :class:`.Constraint` superclass.
"""
- _autoattach = kw.pop('_autoattach', True)
- _column_flag = kw.pop('_column_flag', False)
+ _autoattach = kw.pop("_autoattach", True)
+ _column_flag = kw.pop("_column_flag", False)
Constraint.__init__(self, **kw)
ColumnCollectionMixin.__init__(
- self, *columns, _autoattach=_autoattach, _column_flag=_column_flag)
+ self, *columns, _autoattach=_autoattach, _column_flag=_column_flag
+ )
columns = None
"""A :class:`.ColumnCollection` representing the set of columns
@@ -2714,8 +2841,12 @@ class ColumnCollectionConstraint(ColumnCollectionMixin, Constraint):
return x in self.columns
def copy(self, **kw):
- c = self.__class__(name=self.name, deferrable=self.deferrable,
- initially=self.initially, *self.columns.keys())
+ c = self.__class__(
+ name=self.name,
+ deferrable=self.deferrable,
+ initially=self.initially,
+ *self.columns.keys()
+ )
return self._schema_item_copy(c)
def contains_column(self, col):
@@ -2747,9 +2878,19 @@ class CheckConstraint(ColumnCollectionConstraint):
_allow_multiple_tables = True
- def __init__(self, sqltext, name=None, deferrable=None,
- initially=None, table=None, info=None, _create_rule=None,
- _autoattach=True, _type_bound=False, **kw):
+ def __init__(
+ self,
+ sqltext,
+ name=None,
+ deferrable=None,
+ initially=None,
+ table=None,
+ info=None,
+ _create_rule=None,
+ _autoattach=True,
+ _type_bound=False,
+ **kw
+ ):
r"""Construct a CHECK constraint.
:param sqltext:
@@ -2781,14 +2922,19 @@ class CheckConstraint(ColumnCollectionConstraint):
self.sqltext = _literal_as_text(sqltext, warn=False)
columns = []
- visitors.traverse(self.sqltext, {}, {'column': columns.append})
-
- super(CheckConstraint, self).\
- __init__(
- name=name, deferrable=deferrable,
- initially=initially, _create_rule=_create_rule, info=info,
- _type_bound=_type_bound, _autoattach=_autoattach,
- *columns, **kw)
+ visitors.traverse(self.sqltext, {}, {"column": columns.append})
+
+ super(CheckConstraint, self).__init__(
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ _create_rule=_create_rule,
+ info=info,
+ _type_bound=_type_bound,
+ _autoattach=_autoattach,
+ *columns,
+ **kw
+ )
if table is not None:
self._set_parent_with_dispatch(table)
@@ -2797,22 +2943,24 @@ class CheckConstraint(ColumnCollectionConstraint):
return "check_constraint"
else:
return "column_check_constraint"
+
__visit_name__ = property(__visit_name__)
def copy(self, target_table=None, **kw):
if target_table is not None:
- sqltext = _copy_expression(
- self.sqltext, self.table, target_table)
+ sqltext = _copy_expression(self.sqltext, self.table, target_table)
else:
sqltext = self.sqltext
- c = CheckConstraint(sqltext,
- name=self.name,
- initially=self.initially,
- deferrable=self.deferrable,
- _create_rule=self._create_rule,
- table=target_table,
- _autoattach=False,
- _type_bound=self._type_bound)
+ c = CheckConstraint(
+ sqltext,
+ name=self.name,
+ initially=self.initially,
+ deferrable=self.deferrable,
+ _create_rule=self._create_rule,
+ table=target_table,
+ _autoattach=False,
+ _type_bound=self._type_bound,
+ )
return self._schema_item_copy(c)
@@ -2828,12 +2976,25 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
Examples of foreign key configuration are in :ref:`metadata_foreignkeys`.
"""
- __visit_name__ = 'foreign_key_constraint'
- def __init__(self, columns, refcolumns, name=None, onupdate=None,
- ondelete=None, deferrable=None, initially=None,
- use_alter=False, link_to_name=False, match=None,
- table=None, info=None, **dialect_kw):
+ __visit_name__ = "foreign_key_constraint"
+
+ def __init__(
+ self,
+ columns,
+ refcolumns,
+ name=None,
+ onupdate=None,
+ ondelete=None,
+ deferrable=None,
+ initially=None,
+ use_alter=False,
+ link_to_name=False,
+ match=None,
+ table=None,
+ info=None,
+ **dialect_kw
+ ):
r"""Construct a composite-capable FOREIGN KEY.
:param columns: A sequence of local column names. The named columns
@@ -2905,8 +3066,13 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
"""
Constraint.__init__(
- self, name=name, deferrable=deferrable, initially=initially,
- info=info, **dialect_kw)
+ self,
+ name=name,
+ deferrable=deferrable,
+ initially=initially,
+ info=info,
+ **dialect_kw
+ )
self.onupdate = onupdate
self.ondelete = ondelete
self.link_to_name = link_to_name
@@ -2927,7 +3093,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
raise exc.ArgumentError(
"ForeignKeyConstraint number "
"of constrained columns must match the number of "
- "referenced columns.")
+ "referenced columns."
+ )
# standalone ForeignKeyConstraint - create
# associated ForeignKey objects which will be applied to hosted
@@ -2946,7 +3113,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
deferrable=self.deferrable,
initially=self.initially,
**self.dialect_kwargs
- ) for refcol in refcolumns
+ )
+ for refcol in refcolumns
]
ColumnCollectionMixin.__init__(self, *columns)
@@ -2978,9 +3146,7 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
@property
def _elements(self):
# legacy - provide a dictionary view of (column_key, fk)
- return util.OrderedDict(
- zip(self.column_keys, self.elements)
- )
+ return util.OrderedDict(zip(self.column_keys, self.elements))
@property
def _referred_schema(self):
@@ -3004,18 +3170,14 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
return self.elements[0].column.table
def _validate_dest_table(self, table):
- table_keys = set([elem._table_key()
- for elem in self.elements])
+ table_keys = set([elem._table_key() for elem in self.elements])
if None not in table_keys and len(table_keys) > 1:
elem0, elem1 = sorted(table_keys)[0:2]
raise exc.ArgumentError(
- 'ForeignKeyConstraint on %s(%s) refers to '
- 'multiple remote tables: %s and %s' % (
- table.fullname,
- self._col_description,
- elem0,
- elem1
- ))
+ "ForeignKeyConstraint on %s(%s) refers to "
+ "multiple remote tables: %s and %s"
+ % (table.fullname, self._col_description, elem0, elem1)
+ )
@property
def column_keys(self):
@@ -3034,8 +3196,8 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
return self.columns.keys()
else:
return [
- col.key if isinstance(col, ColumnElement)
- else str(col) for col in self._pending_colargs
+ col.key if isinstance(col, ColumnElement) else str(col)
+ for col in self._pending_colargs
]
@property
@@ -3051,11 +3213,11 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
raise exc.ArgumentError(
"Can't create ForeignKeyConstraint "
"on table '%s': no column "
- "named '%s' is present." % (table.description, ke.args[0]))
+ "named '%s' is present." % (table.description, ke.args[0])
+ )
for col, fk in zip(self.columns, self.elements):
- if not hasattr(fk, 'parent') or \
- fk.parent is not col:
+ if not hasattr(fk, "parent") or fk.parent is not col:
fk._set_parent_with_dispatch(col)
self._validate_dest_table(table)
@@ -3063,13 +3225,16 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
def copy(self, schema=None, target_table=None, **kw):
fkc = ForeignKeyConstraint(
[x.parent.key for x in self.elements],
- [x._get_colspec(
- schema=schema,
- table_name=target_table.name
- if target_table is not None
- and x._table_key() == x.parent.table.key
- else None)
- for x in self.elements],
+ [
+ x._get_colspec(
+ schema=schema,
+ table_name=target_table.name
+ if target_table is not None
+ and x._table_key() == x.parent.table.key
+ else None,
+ )
+ for x in self.elements
+ ],
name=self.name,
onupdate=self.onupdate,
ondelete=self.ondelete,
@@ -3077,11 +3242,9 @@ class ForeignKeyConstraint(ColumnCollectionConstraint):
deferrable=self.deferrable,
initially=self.initially,
link_to_name=self.link_to_name,
- match=self.match
+ match=self.match,
)
- for self_fk, other_fk in zip(
- self.elements,
- fkc.elements):
+ for self_fk, other_fk in zip(self.elements, fkc.elements):
self_fk._schema_item_copy(other_fk)
return self._schema_item_copy(fkc)
@@ -3160,10 +3323,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
"""
- __visit_name__ = 'primary_key_constraint'
+ __visit_name__ = "primary_key_constraint"
def __init__(self, *columns, **kw):
- self._implicit_generated = kw.pop('_implicit_generated', False)
+ self._implicit_generated = kw.pop("_implicit_generated", False)
super(PrimaryKeyConstraint, self).__init__(*columns, **kw)
def _set_parent(self, table):
@@ -3175,18 +3338,21 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
table.constraints.add(self)
table_pks = [c for c in table.c if c.primary_key]
- if self.columns and table_pks and \
- set(table_pks) != set(self.columns.values()):
+ if (
+ self.columns
+ and table_pks
+ and set(table_pks) != set(self.columns.values())
+ ):
util.warn(
"Table '%s' specifies columns %s as primary_key=True, "
"not matching locally specified columns %s; setting the "
"current primary key columns to %s. This warning "
- "may become an exception in a future release" %
- (
+ "may become an exception in a future release"
+ % (
table.name,
", ".join("'%s'" % c.name for c in table_pks),
", ".join("'%s'" % c.name for c in self.columns),
- ", ".join("'%s'" % c.name for c in self.columns)
+ ", ".join("'%s'" % c.name for c in self.columns),
)
)
table_pks[:] = []
@@ -3241,28 +3407,28 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
@util.memoized_property
def _autoincrement_column(self):
-
def _validate_autoinc(col, autoinc_true):
if col.type._type_affinity is None or not issubclass(
- col.type._type_affinity,
- type_api.INTEGERTYPE._type_affinity):
+ col.type._type_affinity, type_api.INTEGERTYPE._type_affinity
+ ):
if autoinc_true:
raise exc.ArgumentError(
"Column type %s on column '%s' is not "
- "compatible with autoincrement=True" % (
- col.type,
- col
- ))
+ "compatible with autoincrement=True" % (col.type, col)
+ )
else:
return False
- elif not isinstance(col.default, (type(None), Sequence)) and \
- not autoinc_true:
- return False
+ elif (
+ not isinstance(col.default, (type(None), Sequence))
+ and not autoinc_true
+ ):
+ return False
elif col.server_default is not None and not autoinc_true:
return False
- elif (
- col.foreign_keys and col.autoincrement
- not in (True, 'ignore_fk')):
+ elif col.foreign_keys and col.autoincrement not in (
+ True,
+ "ignore_fk",
+ ):
return False
return True
@@ -3272,10 +3438,10 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
if col.autoincrement is True:
_validate_autoinc(col, True)
return col
- elif (
- col.autoincrement in ('auto', 'ignore_fk') and
- _validate_autoinc(col, False)
- ):
+ elif col.autoincrement in (
+ "auto",
+ "ignore_fk",
+ ) and _validate_autoinc(col, False):
return col
else:
@@ -3286,8 +3452,8 @@ class PrimaryKeyConstraint(ColumnCollectionConstraint):
if autoinc is not None:
raise exc.ArgumentError(
"Only one Column may be marked "
- "autoincrement=True, found both %s and %s." %
- (col.name, autoinc.name)
+ "autoincrement=True, found both %s and %s."
+ % (col.name, autoinc.name)
)
else:
autoinc = col
@@ -3304,7 +3470,7 @@ class UniqueConstraint(ColumnCollectionConstraint):
UniqueConstraint.
"""
- __visit_name__ = 'unique_constraint'
+ __visit_name__ = "unique_constraint"
class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
@@ -3382,7 +3548,7 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
"""
- __visit_name__ = 'index'
+ __visit_name__ = "index"
def __init__(self, name, *expressions, **kw):
r"""Construct an index object.
@@ -3420,30 +3586,35 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
columns = []
processed_expressions = []
- for expr, column, strname, add_element in self.\
- _extract_col_expression_collection(expressions):
+ for (
+ expr,
+ column,
+ strname,
+ add_element,
+ ) in self._extract_col_expression_collection(expressions):
if add_element is not None:
columns.append(add_element)
processed_expressions.append(expr)
self.expressions = processed_expressions
self.name = quoted_name(name, kw.pop("quote", None))
- self.unique = kw.pop('unique', False)
- _column_flag = kw.pop('_column_flag', False)
- if 'info' in kw:
- self.info = kw.pop('info')
+ self.unique = kw.pop("unique", False)
+ _column_flag = kw.pop("_column_flag", False)
+ if "info" in kw:
+ self.info = kw.pop("info")
# TODO: consider "table" argument being public, but for
# the purpose of the fix here, it starts as private.
- if '_table' in kw:
- table = kw.pop('_table')
+ if "_table" in kw:
+ table = kw.pop("_table")
self._validate_dialect_kwargs(kw)
# will call _set_parent() if table-bound column
# objects are present
ColumnCollectionMixin.__init__(
- self, *columns, _column_flag=_column_flag)
+ self, *columns, _column_flag=_column_flag
+ )
if table is not None:
self._set_parent(table)
@@ -3454,20 +3625,17 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
if self.table is not None and table is not self.table:
raise exc.ArgumentError(
"Index '%s' is against table '%s', and "
- "cannot be associated with table '%s'." % (
- self.name,
- self.table.description,
- table.description
- )
+ "cannot be associated with table '%s'."
+ % (self.name, self.table.description, table.description)
)
self.table = table
table.indexes.add(self)
self.expressions = [
- expr if isinstance(expr, ClauseElement)
- else colexpr
- for expr, colexpr in util.zip_longest(self.expressions,
- self.columns)
+ expr if isinstance(expr, ClauseElement) else colexpr
+ for expr, colexpr in util.zip_longest(
+ self.expressions, self.columns
+ )
]
@property
@@ -3506,17 +3674,16 @@ class Index(DialectKWArgs, ColumnCollectionMixin, SchemaItem):
bind._run_visitor(ddl.SchemaDropper, self)
def __repr__(self):
- return 'Index(%s)' % (
+ return "Index(%s)" % (
", ".join(
- [repr(self.name)] +
- [repr(e) for e in self.expressions] +
- (self.unique and ["unique=True"] or [])
- ))
+ [repr(self.name)]
+ + [repr(e) for e in self.expressions]
+ + (self.unique and ["unique=True"] or [])
+ )
+ )
-DEFAULT_NAMING_CONVENTION = util.immutabledict({
- "ix": 'ix_%(column_0_label)s'
-})
+DEFAULT_NAMING_CONVENTION = util.immutabledict({"ix": "ix_%(column_0_label)s"})
class MetaData(SchemaItem):
@@ -3542,13 +3709,17 @@ class MetaData(SchemaItem):
"""
- __visit_name__ = 'metadata'
-
- def __init__(self, bind=None, reflect=False, schema=None,
- quote_schema=None,
- naming_convention=DEFAULT_NAMING_CONVENTION,
- info=None
- ):
+ __visit_name__ = "metadata"
+
+ def __init__(
+ self,
+ bind=None,
+ reflect=False,
+ schema=None,
+ quote_schema=None,
+ naming_convention=DEFAULT_NAMING_CONVENTION,
+ info=None,
+ ):
"""Create a new MetaData object.
:param bind:
@@ -3712,12 +3883,15 @@ class MetaData(SchemaItem):
self.bind = bind
if reflect:
- util.warn_deprecated("reflect=True is deprecate; please "
- "use the reflect() method.")
+ util.warn_deprecated(
+ "reflect=True is deprecate; please "
+ "use the reflect() method."
+ )
if not bind:
raise exc.ArgumentError(
"A bind must be supplied in conjunction "
- "with reflect=True")
+ "with reflect=True"
+ )
self.reflect()
tables = None
@@ -3735,7 +3909,7 @@ class MetaData(SchemaItem):
"""
def __repr__(self):
- return 'MetaData(bind=%r)' % self.bind
+ return "MetaData(bind=%r)" % self.bind
def __contains__(self, table_or_key):
if not isinstance(table_or_key, util.string_types):
@@ -3755,27 +3929,32 @@ class MetaData(SchemaItem):
for fk in removed.foreign_keys:
fk._remove_from_metadata(self)
if self._schemas:
- self._schemas = set([t.schema
- for t in self.tables.values()
- if t.schema is not None])
+ self._schemas = set(
+ [
+ t.schema
+ for t in self.tables.values()
+ if t.schema is not None
+ ]
+ )
def __getstate__(self):
- return {'tables': self.tables,
- 'schema': self.schema,
- 'schemas': self._schemas,
- 'sequences': self._sequences,
- 'fk_memos': self._fk_memos,
- 'naming_convention': self.naming_convention
- }
+ return {
+ "tables": self.tables,
+ "schema": self.schema,
+ "schemas": self._schemas,
+ "sequences": self._sequences,
+ "fk_memos": self._fk_memos,
+ "naming_convention": self.naming_convention,
+ }
def __setstate__(self, state):
- self.tables = state['tables']
- self.schema = state['schema']
- self.naming_convention = state['naming_convention']
+ self.tables = state["tables"]
+ self.schema = state["schema"]
+ self.naming_convention = state["naming_convention"]
self._bind = None
- self._sequences = state['sequences']
- self._schemas = state['schemas']
- self._fk_memos = state['fk_memos']
+ self._sequences = state["sequences"]
+ self._schemas = state["schemas"]
+ self._fk_memos = state["fk_memos"]
def is_bound(self):
"""True if this MetaData is bound to an Engine or Connection."""
@@ -3805,10 +3984,11 @@ class MetaData(SchemaItem):
def _bind_to(self, url, bind):
"""Bind this MetaData to an Engine, Connection, string or URL."""
- if isinstance(bind, util.string_types + (url.URL, )):
+ if isinstance(bind, util.string_types + (url.URL,)):
self._bind = sqlalchemy.create_engine(bind)
else:
self._bind = bind
+
bind = property(bind, _bind_to)
def clear(self):
@@ -3858,12 +4038,20 @@ class MetaData(SchemaItem):
"""
- return ddl.sort_tables(sorted(self.tables.values(), key=lambda t: t.key))
+ return ddl.sort_tables(
+ sorted(self.tables.values(), key=lambda t: t.key)
+ )
- def reflect(self, bind=None, schema=None, views=False, only=None,
- extend_existing=False,
- autoload_replace=True,
- **dialect_kwargs):
+ def reflect(
+ self,
+ bind=None,
+ schema=None,
+ views=False,
+ only=None,
+ extend_existing=False,
+ autoload_replace=True,
+ **dialect_kwargs
+ ):
r"""Load all available table definitions from the database.
Automatically creates ``Table`` entries in this ``MetaData`` for any
@@ -3926,11 +4114,11 @@ class MetaData(SchemaItem):
with bind.connect() as conn:
reflect_opts = {
- 'autoload': True,
- 'autoload_with': conn,
- 'extend_existing': extend_existing,
- 'autoload_replace': autoload_replace,
- '_extend_on': set()
+ "autoload": True,
+ "autoload_with": conn,
+ "extend_existing": extend_existing,
+ "autoload_replace": autoload_replace,
+ "_extend_on": set(),
}
reflect_opts.update(dialect_kwargs)
@@ -3939,42 +4127,49 @@ class MetaData(SchemaItem):
schema = self.schema
if schema is not None:
- reflect_opts['schema'] = schema
+ reflect_opts["schema"] = schema
available = util.OrderedSet(
- bind.engine.table_names(schema, connection=conn))
+ bind.engine.table_names(schema, connection=conn)
+ )
if views:
- available.update(
- bind.dialect.get_view_names(conn, schema)
- )
+ available.update(bind.dialect.get_view_names(conn, schema))
if schema is not None:
- available_w_schema = util.OrderedSet(["%s.%s" % (schema, name)
- for name in available])
+ available_w_schema = util.OrderedSet(
+ ["%s.%s" % (schema, name) for name in available]
+ )
else:
available_w_schema = available
current = set(self.tables)
if only is None:
- load = [name for name, schname in
- zip(available, available_w_schema)
- if extend_existing or schname not in current]
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if extend_existing or schname not in current
+ ]
elif util.callable(only):
- load = [name for name, schname in
- zip(available, available_w_schema)
- if (extend_existing or schname not in current)
- and only(name, self)]
+ load = [
+ name
+ for name, schname in zip(available, available_w_schema)
+ if (extend_existing or schname not in current)
+ and only(name, self)
+ ]
else:
missing = [name for name in only if name not in available]
if missing:
- s = schema and (" schema '%s'" % schema) or ''
+ s = schema and (" schema '%s'" % schema) or ""
raise exc.InvalidRequestError(
- 'Could not reflect: requested table(s) not available '
- 'in %r%s: (%s)' %
- (bind.engine, s, ', '.join(missing)))
- load = [name for name in only if extend_existing or
- name not in current]
+ "Could not reflect: requested table(s) not available "
+ "in %r%s: (%s)" % (bind.engine, s, ", ".join(missing))
+ )
+ load = [
+ name
+ for name in only
+ if extend_existing or name not in current
+ ]
for name in load:
try:
@@ -3989,11 +4184,12 @@ class MetaData(SchemaItem):
See :class:`.DDLEvents`.
"""
+
def adapt_listener(target, connection, **kw):
- tables = kw['tables']
+ tables = kw["tables"]
listener(event, target, connection, tables=tables)
- event.listen(self, "" + event_name.replace('-', '_'), adapt_listener)
+ event.listen(self, "" + event_name.replace("-", "_"), adapt_listener)
def create_all(self, bind=None, tables=None, checkfirst=True):
"""Create all tables stored in this metadata.
@@ -4017,10 +4213,9 @@ class MetaData(SchemaItem):
"""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaGenerator,
- self,
- checkfirst=checkfirst,
- tables=tables)
+ bind._run_visitor(
+ ddl.SchemaGenerator, self, checkfirst=checkfirst, tables=tables
+ )
def drop_all(self, bind=None, tables=None, checkfirst=True):
"""Drop all tables stored in this metadata.
@@ -4044,10 +4239,9 @@ class MetaData(SchemaItem):
"""
if bind is None:
bind = _bind_or_error(self)
- bind._run_visitor(ddl.SchemaDropper,
- self,
- checkfirst=checkfirst,
- tables=tables)
+ bind._run_visitor(
+ ddl.SchemaDropper, self, checkfirst=checkfirst, tables=tables
+ )
class ThreadLocalMetaData(MetaData):
@@ -4064,7 +4258,7 @@ class ThreadLocalMetaData(MetaData):
"""
- __visit_name__ = 'metadata'
+ __visit_name__ = "metadata"
def __init__(self):
"""Construct a ThreadLocalMetaData."""
@@ -4080,13 +4274,13 @@ class ThreadLocalMetaData(MetaData):
string or URL to automatically create a basic Engine for this bind
with ``create_engine()``."""
- return getattr(self.context, '_engine', None)
+ return getattr(self.context, "_engine", None)
@util.dependencies("sqlalchemy.engine.url")
def _bind_to(self, url, bind):
"""Bind to a Connectable in the caller's thread."""
- if isinstance(bind, util.string_types + (url.URL, )):
+ if isinstance(bind, util.string_types + (url.URL,)):
try:
self.context._engine = self.__engines[bind]
except KeyError:
@@ -4104,14 +4298,16 @@ class ThreadLocalMetaData(MetaData):
def is_bound(self):
"""True if there is a bind for this thread."""
- return (hasattr(self.context, '_engine') and
- self.context._engine is not None)
+ return (
+ hasattr(self.context, "_engine")
+ and self.context._engine is not None
+ )
def dispose(self):
"""Dispose all bound engines, in all thread contexts."""
for e in self.__engines.values():
- if hasattr(e, 'dispose'):
+ if hasattr(e, "dispose"):
e.dispose()
@@ -4128,22 +4324,25 @@ class _SchemaTranslateMap(object):
"""
- __slots__ = 'map_', '__call__', 'hash_key', 'is_default'
+
+ __slots__ = "map_", "__call__", "hash_key", "is_default"
_default_schema_getter = operator.attrgetter("schema")
def __init__(self, map_):
self.map_ = map_
if map_ is not None:
+
def schema_for_object(obj):
effective_schema = self._default_schema_getter(obj)
effective_schema = obj._translate_schema(
- effective_schema, map_)
+ effective_schema, map_
+ )
return effective_schema
+
self.__call__ = schema_for_object
self.hash_key = ";".join(
- "%s=%s" % (k, map_[k])
- for k in sorted(map_, key=str)
+ "%s=%s" % (k, map_[k]) for k in sorted(map_, key=str)
)
self.is_default = False
else:
@@ -4160,6 +4359,6 @@ class _SchemaTranslateMap(object):
else:
return _SchemaTranslateMap(map_)
+
_default_schema_map = _SchemaTranslateMap(None)
_schema_getter = _SchemaTranslateMap._schema_getter
-
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index f64f152c4..1f1800514 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -10,15 +10,39 @@ SQL tables and derived rowsets.
"""
-from .elements import ClauseElement, TextClause, ClauseList, \
- and_, Grouping, UnaryExpression, literal_column, BindParameter
-from .elements import _clone, \
- _literal_as_text, _interpret_as_column_or_from, _expand_cloned,\
- _select_iterables, _anonymous_label, _clause_element_as_expr,\
- _cloned_intersection, _cloned_difference, True_, \
- _literal_as_label_reference, _literal_and_labels_as_label_reference
-from .base import Immutable, Executable, _generative, \
- ColumnCollection, ColumnSet, _from_objects, Generative
+from .elements import (
+ ClauseElement,
+ TextClause,
+ ClauseList,
+ and_,
+ Grouping,
+ UnaryExpression,
+ literal_column,
+ BindParameter,
+)
+from .elements import (
+ _clone,
+ _literal_as_text,
+ _interpret_as_column_or_from,
+ _expand_cloned,
+ _select_iterables,
+ _anonymous_label,
+ _clause_element_as_expr,
+ _cloned_intersection,
+ _cloned_difference,
+ True_,
+ _literal_as_label_reference,
+ _literal_and_labels_as_label_reference,
+)
+from .base import (
+ Immutable,
+ Executable,
+ _generative,
+ ColumnCollection,
+ ColumnSet,
+ _from_objects,
+ Generative,
+)
from . import type_api
from .. import inspection
from .. import util
@@ -40,7 +64,8 @@ def _interpret_as_from(element):
"Textual SQL FROM expression %(expr)r should be "
"explicitly declared as text(%(expr)r), "
"or use table(%(expr)r) for more specificity",
- {"expr": util.ellipses_string(element)})
+ {"expr": util.ellipses_string(element)},
+ )
return TextClause(util.text_type(element))
try:
@@ -73,7 +98,7 @@ def _offset_or_limit_clause(element, name=None, type_=None):
"""
if element is None:
return None
- elif hasattr(element, '__clause_element__'):
+ elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif isinstance(element, Visitable):
return element
@@ -97,7 +122,8 @@ def _offset_or_limit_clause_asint(clause, attrname):
except AttributeError:
raise exc.CompileError(
"This SELECT structure does not use a simple "
- "integer value for %s" % attrname)
+ "integer value for %s" % attrname
+ )
else:
return util.asint(value)
@@ -225,12 +251,14 @@ def tablesample(selectable, sampling, name=None, seed=None):
"""
return _interpret_as_from(selectable).tablesample(
- sampling, name=name, seed=seed)
+ sampling, name=name, seed=seed
+ )
class Selectable(ClauseElement):
"""mark a class as being selectable"""
- __visit_name__ = 'selectable'
+
+ __visit_name__ = "selectable"
is_selectable = True
@@ -265,15 +293,17 @@ class HasPrefixes(object):
limit rendering of this prefix to only that dialect.
"""
- dialect = kw.pop('dialect', None)
+ dialect = kw.pop("dialect", None)
if kw:
- raise exc.ArgumentError("Unsupported argument(s): %s" %
- ",".join(kw))
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
self._setup_prefixes(expr, dialect)
def _setup_prefixes(self, prefixes, dialect=None):
self._prefixes = self._prefixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in prefixes])
+ [(_literal_as_text(p, warn=False), dialect) for p in prefixes]
+ )
class HasSuffixes(object):
@@ -301,15 +331,17 @@ class HasSuffixes(object):
limit rendering of this suffix to only that dialect.
"""
- dialect = kw.pop('dialect', None)
+ dialect = kw.pop("dialect", None)
if kw:
- raise exc.ArgumentError("Unsupported argument(s): %s" %
- ",".join(kw))
+ raise exc.ArgumentError(
+ "Unsupported argument(s): %s" % ",".join(kw)
+ )
self._setup_suffixes(expr, dialect)
def _setup_suffixes(self, suffixes, dialect=None):
self._suffixes = self._suffixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in suffixes])
+ [(_literal_as_text(p, warn=False), dialect) for p in suffixes]
+ )
class FromClause(Selectable):
@@ -330,7 +362,8 @@ class FromClause(Selectable):
"""
- __visit_name__ = 'fromclause'
+
+ __visit_name__ = "fromclause"
named_with_column = False
_hide_froms = []
@@ -359,13 +392,14 @@ class FromClause(Selectable):
_memoized_property = util.group_expirable_memoized_property(["_columns"])
@util.deprecated(
- '1.1',
+ "1.1",
message="``FromClause.count()`` is deprecated. Counting "
"rows requires that the correct column expression and "
"accommodations for joins, DISTINCT, etc. must be made, "
"otherwise results may not be what's expected. "
"Please use an appropriate ``func.count()`` expression "
- "directly.")
+ "directly.",
+ )
@util.dependencies("sqlalchemy.sql.functions")
def count(self, functions, whereclause=None, **params):
"""return a SELECT COUNT generated against this
@@ -392,10 +426,11 @@ class FromClause(Selectable):
else:
col = list(self.columns)[0]
return Select(
- [functions.func.count(col).label('tbl_row_count')],
+ [functions.func.count(col).label("tbl_row_count")],
whereclause,
from_obj=[self],
- **params)
+ **params
+ )
def select(self, whereclause=None, **params):
"""return a SELECT of this :class:`.FromClause`.
@@ -603,8 +638,9 @@ class FromClause(Selectable):
def embedded(expanded_proxy_set, target_set):
for t in target_set.difference(expanded_proxy_set):
- if not set(_expand_cloned([t])
- ).intersection(expanded_proxy_set):
+ if not set(_expand_cloned([t])).intersection(
+ expanded_proxy_set
+ ):
return False
return True
@@ -617,8 +653,10 @@ class FromClause(Selectable):
for c in cols:
expanded_proxy_set = set(_expand_cloned(c.proxy_set))
i = target_set.intersection(expanded_proxy_set)
- if i and (not require_embedded
- or embedded(expanded_proxy_set, target_set)):
+ if i and (
+ not require_embedded
+ or embedded(expanded_proxy_set, target_set)
+ ):
if col is None:
# no corresponding column yet, pick this one.
@@ -646,12 +684,20 @@ class FromClause(Selectable):
col_distance = util.reduce(
operator.add,
- [sc._annotations.get('weight', 1) for sc in
- col.proxy_set if sc.shares_lineage(column)])
+ [
+ sc._annotations.get("weight", 1)
+ for sc in col.proxy_set
+ if sc.shares_lineage(column)
+ ],
+ )
c_distance = util.reduce(
operator.add,
- [sc._annotations.get('weight', 1) for sc in
- c.proxy_set if sc.shares_lineage(column)])
+ [
+ sc._annotations.get("weight", 1)
+ for sc in c.proxy_set
+ if sc.shares_lineage(column)
+ ],
+ )
if c_distance < col_distance:
col, intersect = c, i
return col
@@ -663,7 +709,7 @@ class FromClause(Selectable):
Used primarily for error message formatting.
"""
- return getattr(self, 'name', self.__class__.__name__ + " object")
+ return getattr(self, "name", self.__class__.__name__ + " object")
def _reset_exported(self):
"""delete memoized collections when a FromClause is cloned."""
@@ -683,7 +729,7 @@ class FromClause(Selectable):
"""
- if '_columns' not in self.__dict__:
+ if "_columns" not in self.__dict__:
self._init_collections()
self._populate_column_collection()
return self._columns.as_immutable()
@@ -706,14 +752,16 @@ class FromClause(Selectable):
self._populate_column_collection()
return self.foreign_keys
- c = property(attrgetter('columns'),
- doc="An alias for the :attr:`.columns` attribute.")
- _select_iterable = property(attrgetter('columns'))
+ c = property(
+ attrgetter("columns"),
+ doc="An alias for the :attr:`.columns` attribute.",
+ )
+ _select_iterable = property(attrgetter("columns"))
def _init_collections(self):
- assert '_columns' not in self.__dict__
- assert 'primary_key' not in self.__dict__
- assert 'foreign_keys' not in self.__dict__
+ assert "_columns" not in self.__dict__
+ assert "primary_key" not in self.__dict__
+ assert "foreign_keys" not in self.__dict__
self._columns = ColumnCollection()
self.primary_key = ColumnSet()
@@ -721,7 +769,7 @@ class FromClause(Selectable):
@property
def _cols_populated(self):
- return '_columns' in self.__dict__
+ return "_columns" in self.__dict__
def _populate_column_collection(self):
"""Called on subclasses to establish the .c collection.
@@ -758,8 +806,7 @@ class FromClause(Selectable):
"""
if not self._cols_populated:
return None
- elif (column.key in self.columns and
- self.columns[column.key] is column):
+ elif column.key in self.columns and self.columns[column.key] is column:
return column
else:
return None
@@ -780,7 +827,8 @@ class Join(FromClause):
:meth:`.FromClause.join`
"""
- __visit_name__ = 'join'
+
+ __visit_name__ = "join"
_is_join = True
@@ -829,8 +877,9 @@ class Join(FromClause):
return cls(left, right, onclause, isouter=True, full=full)
@classmethod
- def _create_join(cls, left, right, onclause=None, isouter=False,
- full=False):
+ def _create_join(
+ cls, left, right, onclause=None, isouter=False, full=False
+ ):
"""Produce a :class:`.Join` object, given two :class:`.FromClause`
expressions.
@@ -882,26 +931,34 @@ class Join(FromClause):
self.left.description,
id(self.left),
self.right.description,
- id(self.right))
+ id(self.right),
+ )
def is_derived_from(self, fromclause):
- return fromclause is self or \
- self.left.is_derived_from(fromclause) or \
- self.right.is_derived_from(fromclause)
+ return (
+ fromclause is self
+ or self.left.is_derived_from(fromclause)
+ or self.right.is_derived_from(fromclause)
+ )
def self_group(self, against=None):
return FromGrouping(self)
@util.dependencies("sqlalchemy.sql.util")
def _populate_column_collection(self, sqlutil):
- columns = [c for c in self.left.columns] + \
- [c for c in self.right.columns]
+ columns = [c for c in self.left.columns] + [
+ c for c in self.right.columns
+ ]
- self.primary_key.extend(sqlutil.reduce_columns(
- (c for c in columns if c.primary_key), self.onclause))
+ self.primary_key.extend(
+ sqlutil.reduce_columns(
+ (c for c in columns if c.primary_key), self.onclause
+ )
+ )
self._columns.update((col._label, col) for col in columns)
- self.foreign_keys.update(itertools.chain(
- *[col.foreign_keys for col in columns]))
+ self.foreign_keys.update(
+ itertools.chain(*[col.foreign_keys for col in columns])
+ )
def _refresh_for_new_column(self, column):
col = self.left._refresh_for_new_column(column)
@@ -933,9 +990,14 @@ class Join(FromClause):
return self._join_condition(left, right, a_subset=left_right)
@classmethod
- def _join_condition(cls, a, b, ignore_nonexistent_tables=False,
- a_subset=None,
- consider_as_foreign_keys=None):
+ def _join_condition(
+ cls,
+ a,
+ b,
+ ignore_nonexistent_tables=False,
+ a_subset=None,
+ consider_as_foreign_keys=None,
+ ):
"""create a join condition between two tables or selectables.
e.g.::
@@ -963,26 +1025,31 @@ class Join(FromClause):
"""
constraints = cls._joincond_scan_left_right(
- a, a_subset, b, consider_as_foreign_keys)
+ a, a_subset, b, consider_as_foreign_keys
+ )
if len(constraints) > 1:
cls._joincond_trim_constraints(
- a, b, constraints, consider_as_foreign_keys)
+ a, b, constraints, consider_as_foreign_keys
+ )
if len(constraints) == 0:
if isinstance(b, FromGrouping):
- hint = " Perhaps you meant to convert the right side to a "\
+ hint = (
+ " Perhaps you meant to convert the right side to a "
"subquery using alias()?"
+ )
else:
hint = ""
raise exc.NoForeignKeysError(
"Can't find any foreign key relationships "
- "between '%s' and '%s'.%s" %
- (a.description, b.description, hint))
+ "between '%s' and '%s'.%s"
+ % (a.description, b.description, hint)
+ )
crit = [(x == y) for x, y in list(constraints.values())[0]]
if len(crit) == 1:
- return (crit[0])
+ return crit[0]
else:
return and_(*crit)
@@ -994,24 +1061,30 @@ class Join(FromClause):
left_right = None
constraints = cls._joincond_scan_left_right(
- a=left, b=right, a_subset=left_right,
- consider_as_foreign_keys=consider_as_foreign_keys)
+ a=left,
+ b=right,
+ a_subset=left_right,
+ consider_as_foreign_keys=consider_as_foreign_keys,
+ )
return bool(constraints)
@classmethod
def _joincond_scan_left_right(
- cls, a, a_subset, b, consider_as_foreign_keys):
+ cls, a, a_subset, b, consider_as_foreign_keys
+ ):
constraints = collections.defaultdict(list)
for left in (a_subset, a):
if left is None:
continue
for fk in sorted(
- b.foreign_keys,
- key=lambda fk: fk.parent._creation_order):
- if consider_as_foreign_keys is not None and \
- fk.parent not in consider_as_foreign_keys:
+ b.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
continue
try:
col = fk.get_referent(left)
@@ -1025,10 +1098,12 @@ class Join(FromClause):
constraints[fk.constraint].append((col, fk.parent))
if left is not b:
for fk in sorted(
- left.foreign_keys,
- key=lambda fk: fk.parent._creation_order):
- if consider_as_foreign_keys is not None and \
- fk.parent not in consider_as_foreign_keys:
+ left.foreign_keys, key=lambda fk: fk.parent._creation_order
+ ):
+ if (
+ consider_as_foreign_keys is not None
+ and fk.parent not in consider_as_foreign_keys
+ ):
continue
try:
col = fk.get_referent(b)
@@ -1046,14 +1121,16 @@ class Join(FromClause):
@classmethod
def _joincond_trim_constraints(
- cls, a, b, constraints, consider_as_foreign_keys):
+ cls, a, b, constraints, consider_as_foreign_keys
+ ):
# more than one constraint matched. narrow down the list
# to include just those FKCs that match exactly to
# "consider_as_foreign_keys".
if consider_as_foreign_keys:
for const in list(constraints):
if set(f.parent for f in const.elements) != set(
- consider_as_foreign_keys):
+ consider_as_foreign_keys
+ ):
del constraints[const]
# if still multiple constraints, but
@@ -1070,8 +1147,8 @@ class Join(FromClause):
"tables have more than one foreign key "
"constraint relationship between them. "
"Please specify the 'onclause' of this "
- "join explicitly." % (a.description, b.description))
-
+ "join explicitly." % (a.description, b.description)
+ )
def select(self, whereclause=None, **kwargs):
r"""Create a :class:`.Select` from this :class:`.Join`.
@@ -1200,27 +1277,37 @@ class Join(FromClause):
"""
if flat:
assert name is None, "Can't send name argument with flat"
- left_a, right_a = self.left.alias(flat=True), \
- self.right.alias(flat=True)
- adapter = sqlutil.ClauseAdapter(left_a).\
- chain(sqlutil.ClauseAdapter(right_a))
+ left_a, right_a = (
+ self.left.alias(flat=True),
+ self.right.alias(flat=True),
+ )
+ adapter = sqlutil.ClauseAdapter(left_a).chain(
+ sqlutil.ClauseAdapter(right_a)
+ )
- return left_a.join(right_a, adapter.traverse(self.onclause),
- isouter=self.isouter, full=self.full)
+ return left_a.join(
+ right_a,
+ adapter.traverse(self.onclause),
+ isouter=self.isouter,
+ full=self.full,
+ )
else:
return self.select(use_labels=True, correlate=False).alias(name)
@property
def _hide_froms(self):
- return itertools.chain(*[_from_objects(x.left, x.right)
- for x in self._cloned_set])
+ return itertools.chain(
+ *[_from_objects(x.left, x.right) for x in self._cloned_set]
+ )
@property
def _from_objects(self):
- return [self] + \
- self.onclause._from_objects + \
- self.left._from_objects + \
- self.right._from_objects
+ return (
+ [self]
+ + self.onclause._from_objects
+ + self.left._from_objects
+ + self.right._from_objects
+ )
class Alias(FromClause):
@@ -1236,7 +1323,7 @@ class Alias(FromClause):
"""
- __visit_name__ = 'alias'
+ __visit_name__ = "alias"
named_with_column = True
_is_from_container = True
@@ -1252,15 +1339,16 @@ class Alias(FromClause):
self.element = selectable
if name is None:
if self.original.named_with_column:
- name = getattr(self.original, 'name', None)
- name = _anonymous_label('%%(%d %s)s' % (id(self), name
- or 'anon'))
+ name = getattr(self.original, "name", None)
+ name = _anonymous_label("%%(%d %s)s" % (id(self), name or "anon"))
self.name = name
def self_group(self, against=None):
- if isinstance(against, CompoundSelect) and \
- isinstance(self.original, Select) and \
- self.original._needs_parens_for_grouping():
+ if (
+ isinstance(against, CompoundSelect)
+ and isinstance(self.original, Select)
+ and self.original._needs_parens_for_grouping()
+ ):
return FromGrouping(self)
return super(Alias, self).self_group(against=against)
@@ -1270,14 +1358,15 @@ class Alias(FromClause):
if util.py3k:
return self.name
else:
- return self.name.encode('ascii', 'backslashreplace')
+ return self.name.encode("ascii", "backslashreplace")
def as_scalar(self):
try:
return self.element.as_scalar()
except AttributeError:
- raise AttributeError("Element %s does not support "
- "'as_scalar()'" % self.element)
+ raise AttributeError(
+ "Element %s does not support " "'as_scalar()'" % self.element
+ )
def is_derived_from(self, fromclause):
if fromclause in self._cloned_set:
@@ -1344,7 +1433,7 @@ class Lateral(Alias):
"""
- __visit_name__ = 'lateral'
+ __visit_name__ = "lateral"
_is_lateral = True
@@ -1363,11 +1452,9 @@ class TableSample(Alias):
"""
- __visit_name__ = 'tablesample'
+ __visit_name__ = "tablesample"
- def __init__(self, selectable, sampling,
- name=None,
- seed=None):
+ def __init__(self, selectable, sampling, name=None, seed=None):
self.sampling = sampling
self.seed = seed
super(TableSample, self).__init__(selectable, name=name)
@@ -1390,14 +1477,18 @@ class CTE(Generative, HasSuffixes, Alias):
.. versionadded:: 0.7.6
"""
- __visit_name__ = 'cte'
-
- def __init__(self, selectable,
- name=None,
- recursive=False,
- _cte_alias=None,
- _restates=frozenset(),
- _suffixes=None):
+
+ __visit_name__ = "cte"
+
+ def __init__(
+ self,
+ selectable,
+ name=None,
+ recursive=False,
+ _cte_alias=None,
+ _restates=frozenset(),
+ _suffixes=None,
+ ):
self.recursive = recursive
self._cte_alias = _cte_alias
self._restates = _restates
@@ -1409,9 +1500,9 @@ class CTE(Generative, HasSuffixes, Alias):
super(CTE, self)._copy_internals(clone, **kw)
if self._cte_alias is not None:
self._cte_alias = clone(self._cte_alias, **kw)
- self._restates = frozenset([
- clone(elem, **kw) for elem in self._restates
- ])
+ self._restates = frozenset(
+ [clone(elem, **kw) for elem in self._restates]
+ )
@util.dependencies("sqlalchemy.sql.dml")
def _populate_column_collection(self, dml):
@@ -1428,7 +1519,7 @@ class CTE(Generative, HasSuffixes, Alias):
name=name,
recursive=self.recursive,
_cte_alias=self,
- _suffixes=self._suffixes
+ _suffixes=self._suffixes,
)
def union(self, other):
@@ -1437,7 +1528,7 @@ class CTE(Generative, HasSuffixes, Alias):
name=self.name,
recursive=self.recursive,
_restates=self._restates.union([self]),
- _suffixes=self._suffixes
+ _suffixes=self._suffixes,
)
def union_all(self, other):
@@ -1446,7 +1537,7 @@ class CTE(Generative, HasSuffixes, Alias):
name=self.name,
recursive=self.recursive,
_restates=self._restates.union([self]),
- _suffixes=self._suffixes
+ _suffixes=self._suffixes,
)
@@ -1620,7 +1711,8 @@ class HasCTE(object):
class FromGrouping(FromClause):
"""Represent a grouping of a FROM clause"""
- __visit_name__ = 'grouping'
+
+ __visit_name__ = "grouping"
def __init__(self, element):
self.element = element
@@ -1651,7 +1743,7 @@ class FromGrouping(FromClause):
return self.element._hide_froms
def get_children(self, **kwargs):
- return self.element,
+ return (self.element,)
def _copy_internals(self, clone=_clone, **kw):
self.element = clone(self.element, **kw)
@@ -1664,10 +1756,10 @@ class FromGrouping(FromClause):
return getattr(self.element, attr)
def __getstate__(self):
- return {'element': self.element}
+ return {"element": self.element}
def __setstate__(self, state):
- self.element = state['element']
+ self.element = state["element"]
class TableClause(Immutable, FromClause):
@@ -1699,7 +1791,7 @@ class TableClause(Immutable, FromClause):
"""
- __visit_name__ = 'table'
+ __visit_name__ = "table"
named_with_column = True
@@ -1744,7 +1836,7 @@ class TableClause(Immutable, FromClause):
if util.py3k:
return self.name
else:
- return self.name.encode('ascii', 'backslashreplace')
+ return self.name.encode("ascii", "backslashreplace")
def append_column(self, c):
self._columns[c.key] = c
@@ -1773,7 +1865,8 @@ class TableClause(Immutable, FromClause):
@util.dependencies("sqlalchemy.sql.dml")
def update(
- self, dml, whereclause=None, values=None, inline=False, **kwargs):
+ self, dml, whereclause=None, values=None, inline=False, **kwargs
+ ):
"""Generate an :func:`.update` construct against this
:class:`.TableClause`.
@@ -1785,8 +1878,13 @@ class TableClause(Immutable, FromClause):
"""
- return dml.Update(self, whereclause=whereclause,
- values=values, inline=inline, **kwargs)
+ return dml.Update(
+ self,
+ whereclause=whereclause,
+ values=values,
+ inline=inline,
+ **kwargs
+ )
@util.dependencies("sqlalchemy.sql.dml")
def delete(self, dml, whereclause=None, **kwargs):
@@ -1809,7 +1907,6 @@ class TableClause(Immutable, FromClause):
class ForUpdateArg(ClauseElement):
-
@classmethod
def parse_legacy_select(self, arg):
"""Parse the for_update argument of :func:`.select`.
@@ -1836,11 +1933,11 @@ class ForUpdateArg(ClauseElement):
return None
nowait = read = False
- if arg == 'nowait':
+ if arg == "nowait":
nowait = True
- elif arg == 'read':
+ elif arg == "read":
read = True
- elif arg == 'read_nowait':
+ elif arg == "read_nowait":
read = nowait = True
elif arg is not True:
raise exc.ArgumentError("Unknown for_update argument: %r" % arg)
@@ -1860,12 +1957,12 @@ class ForUpdateArg(ClauseElement):
def __eq__(self, other):
return (
- isinstance(other, ForUpdateArg) and
- other.nowait == self.nowait and
- other.read == self.read and
- other.skip_locked == self.skip_locked and
- other.key_share == self.key_share and
- other.of is self.of
+ isinstance(other, ForUpdateArg)
+ and other.nowait == self.nowait
+ and other.read == self.read
+ and other.skip_locked == self.skip_locked
+ and other.key_share == self.key_share
+ and other.of is self.of
)
def __hash__(self):
@@ -1876,8 +1973,13 @@ class ForUpdateArg(ClauseElement):
self.of = [clone(col, **kw) for col in self.of]
def __init__(
- self, nowait=False, read=False, of=None,
- skip_locked=False, key_share=False):
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
"""Represents arguments specified to :meth:`.Select.for_update`.
.. versionadded:: 0.9.0
@@ -1889,8 +1991,9 @@ class ForUpdateArg(ClauseElement):
self.skip_locked = skip_locked
self.key_share = key_share
if of is not None:
- self.of = [_interpret_as_column_or_from(elem)
- for elem in util.to_list(of)]
+ self.of = [
+ _interpret_as_column_or_from(elem) for elem in util.to_list(of)
+ ]
else:
self.of = None
@@ -1930,17 +2033,20 @@ class SelectBase(HasCTE, Executable, FromClause):
return self.as_scalar().label(name)
@_generative
- @util.deprecated('0.6',
- message="``autocommit()`` is deprecated. Use "
- ":meth:`.Executable.execution_options` with the "
- "'autocommit' flag.")
+ @util.deprecated(
+ "0.6",
+ message="``autocommit()`` is deprecated. Use "
+ ":meth:`.Executable.execution_options` with the "
+ "'autocommit' flag.",
+ )
def autocommit(self):
"""return a new selectable with the 'autocommit' flag set to
True.
"""
- self._execution_options = \
- self._execution_options.union({'autocommit': True})
+ self._execution_options = self._execution_options.union(
+ {"autocommit": True}
+ )
def _generate(self):
"""Override the default _generate() method to also clear out
@@ -1973,34 +2079,38 @@ class GenerativeSelect(SelectBase):
used for other SELECT-like objects, e.g. :class:`.TextAsFrom`.
"""
+
_order_by_clause = ClauseList()
_group_by_clause = ClauseList()
_limit_clause = None
_offset_clause = None
_for_update_arg = None
- def __init__(self,
- use_labels=False,
- for_update=False,
- limit=None,
- offset=None,
- order_by=None,
- group_by=None,
- bind=None,
- autocommit=None):
+ def __init__(
+ self,
+ use_labels=False,
+ for_update=False,
+ limit=None,
+ offset=None,
+ order_by=None,
+ group_by=None,
+ bind=None,
+ autocommit=None,
+ ):
self.use_labels = use_labels
if for_update is not False:
- self._for_update_arg = (ForUpdateArg.
- parse_legacy_select(for_update))
+ self._for_update_arg = ForUpdateArg.parse_legacy_select(for_update)
if autocommit is not None:
- util.warn_deprecated('autocommit on select() is '
- 'deprecated. Use .execution_options(a'
- 'utocommit=True)')
- self._execution_options = \
- self._execution_options.union(
- {'autocommit': autocommit})
+ util.warn_deprecated(
+ "autocommit on select() is "
+ "deprecated. Use .execution_options(a"
+ "utocommit=True)"
+ )
+ self._execution_options = self._execution_options.union(
+ {"autocommit": autocommit}
+ )
if limit is not None:
self._limit_clause = _offset_or_limit_clause(limit)
if offset is not None:
@@ -2010,11 +2120,13 @@ class GenerativeSelect(SelectBase):
if order_by is not None:
self._order_by_clause = ClauseList(
*util.to_list(order_by),
- _literal_as_text=_literal_and_labels_as_label_reference)
+ _literal_as_text=_literal_and_labels_as_label_reference
+ )
if group_by is not None:
self._group_by_clause = ClauseList(
*util.to_list(group_by),
- _literal_as_text=_literal_as_label_reference)
+ _literal_as_text=_literal_as_label_reference
+ )
@property
def for_update(self):
@@ -2030,8 +2142,14 @@ class GenerativeSelect(SelectBase):
self._for_update_arg = ForUpdateArg.parse_legacy_select(value)
@_generative
- def with_for_update(self, nowait=False, read=False, of=None,
- skip_locked=False, key_share=False):
+ def with_for_update(
+ self,
+ nowait=False,
+ read=False,
+ of=None,
+ skip_locked=False,
+ key_share=False,
+ ):
"""Specify a ``FOR UPDATE`` clause for this :class:`.GenerativeSelect`.
E.g.::
@@ -2079,9 +2197,13 @@ class GenerativeSelect(SelectBase):
.. versionadded:: 1.1.0
"""
- self._for_update_arg = ForUpdateArg(nowait=nowait, read=read, of=of,
- skip_locked=skip_locked,
- key_share=key_share)
+ self._for_update_arg = ForUpdateArg(
+ nowait=nowait,
+ read=read,
+ of=of,
+ skip_locked=skip_locked,
+ key_share=key_share,
+ )
@_generative
def apply_labels(self):
@@ -2209,11 +2331,12 @@ class GenerativeSelect(SelectBase):
if len(clauses) == 1 and clauses[0] is None:
self._order_by_clause = ClauseList()
else:
- if getattr(self, '_order_by_clause', None) is not None:
+ if getattr(self, "_order_by_clause", None) is not None:
clauses = list(self._order_by_clause) + list(clauses)
self._order_by_clause = ClauseList(
*clauses,
- _literal_as_text=_literal_and_labels_as_label_reference)
+ _literal_as_text=_literal_and_labels_as_label_reference
+ )
def append_group_by(self, *clauses):
"""Append the given GROUP BY criterion applied to this selectable.
@@ -2228,10 +2351,11 @@ class GenerativeSelect(SelectBase):
if len(clauses) == 1 and clauses[0] is None:
self._group_by_clause = ClauseList()
else:
- if getattr(self, '_group_by_clause', None) is not None:
+ if getattr(self, "_group_by_clause", None) is not None:
clauses = list(self._group_by_clause) + list(clauses)
self._group_by_clause = ClauseList(
- *clauses, _literal_as_text=_literal_as_label_reference)
+ *clauses, _literal_as_text=_literal_as_label_reference
+ )
@property
def _label_resolve_dict(self):
@@ -2265,19 +2389,19 @@ class CompoundSelect(GenerativeSelect):
"""
- __visit_name__ = 'compound_select'
+ __visit_name__ = "compound_select"
- UNION = util.symbol('UNION')
- UNION_ALL = util.symbol('UNION ALL')
- EXCEPT = util.symbol('EXCEPT')
- EXCEPT_ALL = util.symbol('EXCEPT ALL')
- INTERSECT = util.symbol('INTERSECT')
- INTERSECT_ALL = util.symbol('INTERSECT ALL')
+ UNION = util.symbol("UNION")
+ UNION_ALL = util.symbol("UNION ALL")
+ EXCEPT = util.symbol("EXCEPT")
+ EXCEPT_ALL = util.symbol("EXCEPT ALL")
+ INTERSECT = util.symbol("INTERSECT")
+ INTERSECT_ALL = util.symbol("INTERSECT ALL")
_is_from_container = True
def __init__(self, keyword, *selects, **kwargs):
- self._auto_correlate = kwargs.pop('correlate', False)
+ self._auto_correlate = kwargs.pop("correlate", False)
self.keyword = keyword
self.selects = []
@@ -2291,12 +2415,16 @@ class CompoundSelect(GenerativeSelect):
numcols = len(s.c._all_columns)
elif len(s.c._all_columns) != numcols:
raise exc.ArgumentError(
- 'All selectables passed to '
- 'CompoundSelect must have identical numbers of '
- 'columns; select #%d has %d columns, select '
- '#%d has %d' %
- (1, len(self.selects[0].c._all_columns),
- n + 1, len(s.c._all_columns))
+ "All selectables passed to "
+ "CompoundSelect must have identical numbers of "
+ "columns; select #%d has %d columns, select "
+ "#%d has %d"
+ % (
+ 1,
+ len(self.selects[0].c._all_columns),
+ n + 1,
+ len(s.c._all_columns),
+ )
)
self.selects.append(s.self_group(against=self))
@@ -2305,9 +2433,7 @@ class CompoundSelect(GenerativeSelect):
@property
def _label_resolve_dict(self):
- d = dict(
- (c.key, c) for c in self.c
- )
+ d = dict((c.key, c) for c in self.c)
return d, d, d
@classmethod
@@ -2416,8 +2542,7 @@ class CompoundSelect(GenerativeSelect):
:func:`select`.
"""
- return CompoundSelect(
- CompoundSelect.INTERSECT_ALL, *selects, **kwargs)
+ return CompoundSelect(CompoundSelect.INTERSECT_ALL, *selects, **kwargs)
def _scalar_type(self):
return self.selects[0]._scalar_type()
@@ -2445,8 +2570,10 @@ class CompoundSelect(GenerativeSelect):
# those fks too.
proxy = cols[0]._make_proxy(
- self, name=cols[0]._label if self.use_labels else None,
- key=cols[0]._key_label if self.use_labels else None)
+ self,
+ name=cols[0]._label if self.use_labels else None,
+ key=cols[0]._key_label if self.use_labels else None,
+ )
# hand-construct the "_proxies" collection to include all
# derived columns place a 'weight' annotation corresponding
@@ -2455,7 +2582,8 @@ class CompoundSelect(GenerativeSelect):
# conflicts
proxy._proxies = [
- c._annotate({'weight': i + 1}) for (i, c) in enumerate(cols)]
+ c._annotate({"weight": i + 1}) for (i, c) in enumerate(cols)
+ ]
def _refresh_for_new_column(self, column):
for s in self.selects:
@@ -2464,25 +2592,32 @@ class CompoundSelect(GenerativeSelect):
if not self._cols_populated:
return None
- raise NotImplementedError("CompoundSelect constructs don't support "
- "addition of columns to underlying "
- "selectables")
+ raise NotImplementedError(
+ "CompoundSelect constructs don't support "
+ "addition of columns to underlying "
+ "selectables"
+ )
def _copy_internals(self, clone=_clone, **kw):
super(CompoundSelect, self)._copy_internals(clone, **kw)
self._reset_exported()
self.selects = [clone(s, **kw) for s in self.selects]
- if hasattr(self, '_col_map'):
+ if hasattr(self, "_col_map"):
del self._col_map
for attr in (
- '_order_by_clause', '_group_by_clause', '_for_update_arg'):
+ "_order_by_clause",
+ "_group_by_clause",
+ "_for_update_arg",
+ ):
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
def get_children(self, column_collections=True, **kwargs):
- return (column_collections and list(self.c) or []) \
- + [self._order_by_clause, self._group_by_clause] \
+ return (
+ (column_collections and list(self.c) or [])
+ + [self._order_by_clause, self._group_by_clause]
+ list(self.selects)
+ )
def bind(self):
if self._bind:
@@ -2496,6 +2631,7 @@ class CompoundSelect(GenerativeSelect):
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
@@ -2504,7 +2640,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
- __visit_name__ = 'select'
+ __visit_name__ = "select"
_prefixes = ()
_suffixes = ()
@@ -2517,16 +2653,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
_memoized_property = SelectBase._memoized_property
_is_select = True
- def __init__(self,
- columns=None,
- whereclause=None,
- from_obj=None,
- distinct=False,
- having=None,
- correlate=True,
- prefixes=None,
- suffixes=None,
- **kwargs):
+ def __init__(
+ self,
+ columns=None,
+ whereclause=None,
+ from_obj=None,
+ distinct=False,
+ having=None,
+ correlate=True,
+ prefixes=None,
+ suffixes=None,
+ **kwargs
+ ):
"""Construct a new :class:`.Select`.
Similar functionality is also available via the
@@ -2729,22 +2867,23 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._distinct = True
else:
self._distinct = [
- _literal_as_text(e)
- for e in util.to_list(distinct)
+ _literal_as_text(e) for e in util.to_list(distinct)
]
if from_obj is not None:
self._from_obj = util.OrderedSet(
- _interpret_as_from(f)
- for f in util.to_list(from_obj))
+ _interpret_as_from(f) for f in util.to_list(from_obj)
+ )
else:
self._from_obj = util.OrderedSet()
try:
cols_present = bool(columns)
except TypeError:
- raise exc.ArgumentError("columns argument to select() must "
- "be a Python list or other iterable")
+ raise exc.ArgumentError(
+ "columns argument to select() must "
+ "be a Python list or other iterable"
+ )
if cols_present:
self._raw_columns = []
@@ -2757,14 +2896,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._raw_columns = []
if whereclause is not None:
- self._whereclause = _literal_as_text(
- whereclause).self_group(against=operators._asbool)
+ self._whereclause = _literal_as_text(whereclause).self_group(
+ against=operators._asbool
+ )
else:
self._whereclause = None
if having is not None:
- self._having = _literal_as_text(
- having).self_group(against=operators._asbool)
+ self._having = _literal_as_text(having).self_group(
+ against=operators._asbool
+ )
else:
self._having = None
@@ -2789,12 +2930,14 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
for item in itertools.chain(
_from_objects(*self._raw_columns),
_from_objects(self._whereclause)
- if self._whereclause is not None else (),
- self._from_obj
+ if self._whereclause is not None
+ else (),
+ self._from_obj,
):
if item is self:
raise exc.InvalidRequestError(
- "select() construct refers to itself as a FROM")
+ "select() construct refers to itself as a FROM"
+ )
if translate and item in translate:
item = translate[item]
if not seen.intersection(item._cloned_set):
@@ -2803,8 +2946,9 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return froms
- def _get_display_froms(self, explicit_correlate_froms=None,
- implicit_correlate_froms=None):
+ def _get_display_froms(
+ self, explicit_correlate_froms=None, implicit_correlate_froms=None
+ ):
"""Return the full list of 'from' clauses to be displayed.
Takes into account a set of existing froms which may be
@@ -2815,17 +2959,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
froms = self._froms
- toremove = set(itertools.chain(*[
- _expand_cloned(f._hide_froms)
- for f in froms]))
+ toremove = set(
+ itertools.chain(*[_expand_cloned(f._hide_froms) for f in froms])
+ )
if toremove:
# if we're maintaining clones of froms,
# add the copies out to the toremove list. only include
# clones that are lexical equivalents.
if self._from_cloned:
toremove.update(
- self._from_cloned[f] for f in
- toremove.intersection(self._from_cloned)
+ self._from_cloned[f]
+ for f in toremove.intersection(self._from_cloned)
if self._from_cloned[f]._is_lexical_equivalent(f)
)
# filter out to FROM clauses not in the list,
@@ -2836,41 +2980,53 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
to_correlate = self._correlate
if to_correlate:
froms = [
- f for f in froms if f not in
- _cloned_intersection(
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(
_cloned_intersection(
- froms, explicit_correlate_froms or ()),
- to_correlate
+ froms, explicit_correlate_froms or ()
+ ),
+ to_correlate,
)
]
if self._correlate_except is not None:
froms = [
- f for f in froms if f not in
- _cloned_difference(
+ f
+ for f in froms
+ if f
+ not in _cloned_difference(
_cloned_intersection(
- froms, explicit_correlate_froms or ()),
- self._correlate_except
+ froms, explicit_correlate_froms or ()
+ ),
+ self._correlate_except,
)
]
- if self._auto_correlate and \
- implicit_correlate_froms and \
- len(froms) > 1:
+ if (
+ self._auto_correlate
+ and implicit_correlate_froms
+ and len(froms) > 1
+ ):
froms = [
- f for f in froms if f not in
- _cloned_intersection(froms, implicit_correlate_froms)
+ f
+ for f in froms
+ if f
+ not in _cloned_intersection(froms, implicit_correlate_froms)
]
if not len(froms):
- raise exc.InvalidRequestError("Select statement '%s"
- "' returned no FROM clauses "
- "due to auto-correlation; "
- "specify correlate(<tables>) "
- "to control correlation "
- "manually." % self)
+ raise exc.InvalidRequestError(
+ "Select statement '%s"
+ "' returned no FROM clauses "
+ "due to auto-correlation; "
+ "specify correlate(<tables>) "
+ "to control correlation "
+ "manually." % self
+ )
return froms
@@ -2885,7 +3041,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return self._get_display_froms()
- def with_statement_hint(self, text, dialect_name='*'):
+ def with_statement_hint(self, text, dialect_name="*"):
"""add a statement hint to this :class:`.Select`.
This method is similar to :meth:`.Select.with_hint` except that
@@ -2906,7 +3062,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return self.with_hint(None, text, dialect_name)
@_generative
- def with_hint(self, selectable, text, dialect_name='*'):
+ def with_hint(self, selectable, text, dialect_name="*"):
r"""Add an indexing or other executional context hint for the given
selectable to this :class:`.Select`.
@@ -2940,17 +3096,18 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
if selectable is None:
- self._statement_hints += ((dialect_name, text), )
+ self._statement_hints += ((dialect_name, text),)
else:
- self._hints = self._hints.union(
- {(selectable, dialect_name): text})
+ self._hints = self._hints.union({(selectable, dialect_name): text})
@property
def type(self):
- raise exc.InvalidRequestError("Select objects don't have a type. "
- "Call as_scalar() on this Select "
- "object to return a 'scalar' version "
- "of this Select.")
+ raise exc.InvalidRequestError(
+ "Select objects don't have a type. "
+ "Call as_scalar() on this Select "
+ "object to return a 'scalar' version "
+ "of this Select."
+ )
@_memoized_property.method
def locate_all_froms(self):
@@ -2977,10 +3134,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
with_cols = dict(
(c._resolve_label or c._label or c.key, c)
for c in _select_iterables(self._raw_columns)
- if c._allow_label_resolve)
+ if c._allow_label_resolve
+ )
only_froms = dict(
- (c.key, c) for c in
- _select_iterables(self.froms) if c._allow_label_resolve)
+ (c.key, c)
+ for c in _select_iterables(self.froms)
+ if c._allow_label_resolve
+ )
only_cols = with_cols.copy()
for key, value in only_froms.items():
with_cols.setdefault(key, value)
@@ -3011,11 +3171,13 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
# gets cleared on each generation. previously we were "baking"
# _froms into self._from_obj.
self._from_cloned = from_cloned = dict(
- (f, clone(f, **kw)) for f in self._from_obj.union(self._froms))
+ (f, clone(f, **kw)) for f in self._from_obj.union(self._froms)
+ )
# 3. update persistent _from_obj with the cloned versions.
- self._from_obj = util.OrderedSet(from_cloned[f] for f in
- self._from_obj)
+ self._from_obj = util.OrderedSet(
+ from_cloned[f] for f in self._from_obj
+ )
# the _correlate collection is done separately, what can happen
# here is the same item is _correlate as in _from_obj but the
@@ -3023,16 +3185,22 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
# RelationshipProperty.Comparator._criterion_exists() does
# this). Also keep _correlate liberally open with its previous
# contents, as this set is used for matching, not rendering.
- self._correlate = set(clone(f) for f in
- self._correlate).union(self._correlate)
+ self._correlate = set(clone(f) for f in self._correlate).union(
+ self._correlate
+ )
# 4. clone other things. The difficulty here is that Column
# objects are not actually cloned, and refer to their original
# .table, resulting in the wrong "from" parent after a clone
# operation. Hence _from_cloned and _from_obj supersede what is
# present here.
self._raw_columns = [clone(c, **kw) for c in self._raw_columns]
- for attr in '_whereclause', '_having', '_order_by_clause', \
- '_group_by_clause', '_for_update_arg':
+ for attr in (
+ "_whereclause",
+ "_having",
+ "_order_by_clause",
+ "_group_by_clause",
+ "_for_update_arg",
+ ):
if getattr(self, attr) is not None:
setattr(self, attr, clone(getattr(self, attr), **kw))
@@ -3043,12 +3211,21 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
def get_children(self, column_collections=True, **kwargs):
"""return child elements as per the ClauseElement specification."""
- return (column_collections and list(self.columns) or []) + \
- self._raw_columns + list(self._froms) + \
- [x for x in
- (self._whereclause, self._having,
- self._order_by_clause, self._group_by_clause)
- if x is not None]
+ return (
+ (column_collections and list(self.columns) or [])
+ + self._raw_columns
+ + list(self._froms)
+ + [
+ x
+ for x in (
+ self._whereclause,
+ self._having,
+ self._order_by_clause,
+ self._group_by_clause,
+ )
+ if x is not None
+ ]
+ )
@_generative
def column(self, column):
@@ -3094,7 +3271,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
sqlutil.reduce_columns(
self.inner_columns,
only_synonyms=only_synonyms,
- *(self._whereclause, ) + tuple(self._from_obj)
+ *(self._whereclause,) + tuple(self._from_obj)
)
)
@@ -3307,7 +3484,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._correlate = ()
else:
self._correlate = set(self._correlate).union(
- _interpret_as_from(f) for f in fromclauses)
+ _interpret_as_from(f) for f in fromclauses
+ )
@_generative
def correlate_except(self, *fromclauses):
@@ -3349,7 +3527,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._correlate_except = ()
else:
self._correlate_except = set(self._correlate_except or ()).union(
- _interpret_as_from(f) for f in fromclauses)
+ _interpret_as_from(f) for f in fromclauses
+ )
def append_correlation(self, fromclause):
"""append the given correlation expression to this select()
@@ -3363,7 +3542,8 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self._auto_correlate = False
self._correlate = set(self._correlate).union(
- _interpret_as_from(f) for f in fromclause)
+ _interpret_as_from(f) for f in fromclause
+ )
def append_column(self, column):
"""append the given column expression to the columns clause of this
@@ -3415,8 +3595,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
"""
self._reset_exported()
- self._whereclause = and_(
- True_._ifnone(self._whereclause), whereclause)
+ self._whereclause = and_(True_._ifnone(self._whereclause), whereclause)
def append_having(self, having):
"""append the given expression to this select() construct's HAVING
@@ -3463,19 +3642,17 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
return [
name_for_col(c)
- for c in util.unique_list(
- _select_iterables(self._raw_columns))
+ for c in util.unique_list(_select_iterables(self._raw_columns))
]
else:
return [
(None, c)
- for c in util.unique_list(
- _select_iterables(self._raw_columns))
+ for c in util.unique_list(_select_iterables(self._raw_columns))
]
def _populate_column_collection(self):
for name, c in self._columns_plus_names:
- if not hasattr(c, '_make_proxy'):
+ if not hasattr(c, "_make_proxy"):
continue
if name is None:
key = None
@@ -3486,9 +3663,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
else:
key = None
- c._make_proxy(self, key=key,
- name=name,
- name_is_truncatable=True)
+ c._make_proxy(self, key=key, name=name, name_is_truncatable=True)
def _refresh_for_new_column(self, column):
for fromclause in self._froms:
@@ -3501,15 +3676,16 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
self,
name=col._label if self.use_labels else None,
key=col._key_label if self.use_labels else None,
- name_is_truncatable=True)
+ name_is_truncatable=True,
+ )
return None
return None
def _needs_parens_for_grouping(self):
return (
- self._limit_clause is not None or
- self._offset_clause is not None or
- bool(self._order_by_clause.clauses)
+ self._limit_clause is not None
+ or self._offset_clause is not None
+ or bool(self._order_by_clause.clauses)
)
def self_group(self, against=None):
@@ -3521,8 +3697,10 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
expressions and should not require explicit use.
"""
- if isinstance(against, CompoundSelect) and \
- not self._needs_parens_for_grouping():
+ if (
+ isinstance(against, CompoundSelect)
+ and not self._needs_parens_for_grouping()
+ ):
return self
return FromGrouping(self)
@@ -3586,6 +3764,7 @@ class Select(HasPrefixes, HasSuffixes, GenerativeSelect):
def _set_bind(self, bind):
self._bind = bind
+
bind = property(bind, _set_bind)
@@ -3600,9 +3779,12 @@ class ScalarSelect(Generative, Grouping):
@property
def columns(self):
- raise exc.InvalidRequestError('Scalar Select expression has no '
- 'columns; use this object directly '
- 'within a column-level expression.')
+ raise exc.InvalidRequestError(
+ "Scalar Select expression has no "
+ "columns; use this object directly "
+ "within a column-level expression."
+ )
+
c = columns
@_generative
@@ -3621,6 +3803,7 @@ class Exists(UnaryExpression):
"""Represent an ``EXISTS`` clause.
"""
+
__visit_name__ = UnaryExpression.__visit_name__
_from_objects = []
@@ -3646,12 +3829,16 @@ class Exists(UnaryExpression):
s = args[0]
else:
if not args:
- args = ([literal_column('*')],)
+ args = ([literal_column("*")],)
s = Select(*args, **kwargs).as_scalar().self_group()
- UnaryExpression.__init__(self, s, operator=operators.exists,
- type_=type_api.BOOLEANTYPE,
- wraps_column_expression=True)
+ UnaryExpression.__init__(
+ self,
+ s,
+ operator=operators.exists,
+ type_=type_api.BOOLEANTYPE,
+ wraps_column_expression=True,
+ )
def select(self, whereclause=None, **params):
return Select([self], whereclause, **params)
@@ -3706,6 +3893,7 @@ class TextAsFrom(SelectBase):
:meth:`.TextClause.columns`
"""
+
__visit_name__ = "text_as_from"
_textual = True
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index c5708940b..61fc6d3c9 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -15,10 +15,21 @@ import collections
import json
from . import elements
-from .type_api import TypeEngine, TypeDecorator, to_instance, Variant, \
- Emulated, NativeForEmulated
-from .elements import quoted_name, TypeCoerce as type_coerce, _defer_name, \
- Slice, _literal_as_binds
+from .type_api import (
+ TypeEngine,
+ TypeDecorator,
+ to_instance,
+ Variant,
+ Emulated,
+ NativeForEmulated,
+)
+from .elements import (
+ quoted_name,
+ TypeCoerce as type_coerce,
+ _defer_name,
+ Slice,
+ _literal_as_binds,
+)
from .. import exc, util, processors
from .base import _bind_or_error, SchemaEventTarget
from . import operators
@@ -51,14 +62,15 @@ class _LookupExpressionAdapter(object):
def _adapt_expression(self, op, other_comparator):
othertype = other_comparator.type._type_affinity
lookup = self.type._expression_adaptations.get(
- op, self._blank_dict).get(
- othertype, self.type)
+ op, self._blank_dict
+ ).get(othertype, self.type)
if lookup is othertype:
return (op, other_comparator.type)
elif lookup is self.type._type_affinity:
return (op, self.type)
else:
return (op, to_instance(lookup))
+
comparator_factory = Comparator
@@ -68,17 +80,16 @@ class Concatenable(object):
typically strings."""
class Comparator(TypeEngine.Comparator):
-
def _adapt_expression(self, op, other_comparator):
- if (op is operators.add and
- isinstance(
- other_comparator,
- (Concatenable.Comparator, NullType.Comparator)
- )):
+ if op is operators.add and isinstance(
+ other_comparator,
+ (Concatenable.Comparator, NullType.Comparator),
+ ):
return operators.concat_op, self.expr.type
else:
return super(Concatenable.Comparator, self)._adapt_expression(
- op, other_comparator)
+ op, other_comparator
+ )
comparator_factory = Comparator
@@ -94,17 +105,15 @@ class Indexable(object):
"""
class Comparator(TypeEngine.Comparator):
-
def _setup_getitem(self, index):
raise NotImplementedError()
def __getitem__(self, index):
- adjusted_op, adjusted_right_expr, result_type = \
- self._setup_getitem(index)
+ adjusted_op, adjusted_right_expr, result_type = self._setup_getitem(
+ index
+ )
return self.operate(
- adjusted_op,
- adjusted_right_expr,
- result_type=result_type
+ adjusted_op, adjusted_right_expr, result_type=result_type
)
comparator_factory = Comparator
@@ -124,13 +133,16 @@ class String(Concatenable, TypeEngine):
"""
- __visit_name__ = 'string'
+ __visit_name__ = "string"
- def __init__(self, length=None, collation=None,
- convert_unicode=False,
- unicode_error=None,
- _warn_on_bytestring=False
- ):
+ def __init__(
+ self,
+ length=None,
+ collation=None,
+ convert_unicode=False,
+ unicode_error=None,
+ _warn_on_bytestring=False,
+ ):
"""
Create a string-holding type.
@@ -207,9 +219,10 @@ class String(Concatenable, TypeEngine):
strings from a column with varied or corrupted encodings.
"""
- if unicode_error is not None and convert_unicode != 'force':
- raise exc.ArgumentError("convert_unicode must be 'force' "
- "when unicode_error is set.")
+ if unicode_error is not None and convert_unicode != "force":
+ raise exc.ArgumentError(
+ "convert_unicode must be 'force' " "when unicode_error is set."
+ )
self.length = length
self.collation = collation
@@ -222,23 +235,29 @@ class String(Concatenable, TypeEngine):
value = value.replace("'", "''")
if dialect.identifier_preparer._double_percents:
- value = value.replace('%', '%%')
+ value = value.replace("%", "%%")
return "'%s'" % value
+
return process
def bind_processor(self, dialect):
if self.convert_unicode or dialect.convert_unicode:
- if dialect.supports_unicode_binds and \
- self.convert_unicode != 'force':
+ if (
+ dialect.supports_unicode_binds
+ and self.convert_unicode != "force"
+ ):
if self._warn_on_bytestring:
+
def process(value):
if isinstance(value, util.binary_type):
util.warn_limited(
"Unicode type received non-unicode "
"bind param value %r.",
- (util.ellipses_string(value),))
+ (util.ellipses_string(value),),
+ )
return value
+
return process
else:
return None
@@ -253,29 +272,34 @@ class String(Concatenable, TypeEngine):
util.warn_limited(
"Unicode type received non-unicode bind "
"param value %r.",
- (util.ellipses_string(value),))
+ (util.ellipses_string(value),),
+ )
return value
+
return process
else:
return None
def result_processor(self, dialect, coltype):
wants_unicode = self.convert_unicode or dialect.convert_unicode
- needs_convert = wants_unicode and \
- (dialect.returns_unicode_strings is not True or
- self.convert_unicode in ('force', 'force_nocheck'))
+ needs_convert = wants_unicode and (
+ dialect.returns_unicode_strings is not True
+ or self.convert_unicode in ("force", "force_nocheck")
+ )
needs_isinstance = (
- needs_convert and
- dialect.returns_unicode_strings and
- self.convert_unicode != 'force_nocheck'
+ needs_convert
+ and dialect.returns_unicode_strings
+ and self.convert_unicode != "force_nocheck"
)
if needs_convert:
if needs_isinstance:
return processors.to_conditional_unicode_processor_factory(
- dialect.encoding, self.unicode_error)
+ dialect.encoding, self.unicode_error
+ )
else:
return processors.to_unicode_processor_factory(
- dialect.encoding, self.unicode_error)
+ dialect.encoding, self.unicode_error
+ )
else:
return None
@@ -301,7 +325,8 @@ class Text(String):
argument here, it will be rejected by others.
"""
- __visit_name__ = 'text'
+
+ __visit_name__ = "text"
class Unicode(String):
@@ -360,7 +385,7 @@ class Unicode(String):
"""
- __visit_name__ = 'unicode'
+ __visit_name__ = "unicode"
def __init__(self, length=None, **kwargs):
"""
@@ -371,8 +396,8 @@ class Unicode(String):
defaults to ``True``.
"""
- kwargs.setdefault('convert_unicode', True)
- kwargs.setdefault('_warn_on_bytestring', True)
+ kwargs.setdefault("convert_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
super(Unicode, self).__init__(length=length, **kwargs)
@@ -389,7 +414,7 @@ class UnicodeText(Text):
"""
- __visit_name__ = 'unicode_text'
+ __visit_name__ = "unicode_text"
def __init__(self, length=None, **kwargs):
"""
@@ -400,8 +425,8 @@ class UnicodeText(Text):
defaults to ``True``.
"""
- kwargs.setdefault('convert_unicode', True)
- kwargs.setdefault('_warn_on_bytestring', True)
+ kwargs.setdefault("convert_unicode", True)
+ kwargs.setdefault("_warn_on_bytestring", True)
super(UnicodeText, self).__init__(length=length, **kwargs)
@@ -409,7 +434,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
"""A type for ``int`` integers."""
- __visit_name__ = 'integer'
+ __visit_name__ = "integer"
def get_dbapi_type(self, dbapi):
return dbapi.NUMBER
@@ -421,6 +446,7 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
def literal_processor(self, dialect):
def process(value):
return str(value)
+
return process
@util.memoized_property
@@ -438,18 +464,9 @@ class Integer(_LookupExpressionAdapter, TypeEngine):
Integer: self.__class__,
Numeric: Numeric,
},
- operators.div: {
- Integer: self.__class__,
- Numeric: Numeric,
- },
- operators.truediv: {
- Integer: self.__class__,
- Numeric: Numeric,
- },
- operators.sub: {
- Integer: self.__class__,
- Numeric: Numeric,
- },
+ operators.div: {Integer: self.__class__, Numeric: Numeric},
+ operators.truediv: {Integer: self.__class__, Numeric: Numeric},
+ operators.sub: {Integer: self.__class__, Numeric: Numeric},
}
@@ -462,7 +479,7 @@ class SmallInteger(Integer):
"""
- __visit_name__ = 'small_integer'
+ __visit_name__ = "small_integer"
class BigInteger(Integer):
@@ -474,7 +491,7 @@ class BigInteger(Integer):
"""
- __visit_name__ = 'big_integer'
+ __visit_name__ = "big_integer"
class Numeric(_LookupExpressionAdapter, TypeEngine):
@@ -517,12 +534,17 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
"""
- __visit_name__ = 'numeric'
+ __visit_name__ = "numeric"
_default_decimal_return_scale = 10
- def __init__(self, precision=None, scale=None,
- decimal_return_scale=None, asdecimal=True):
+ def __init__(
+ self,
+ precision=None,
+ scale=None,
+ decimal_return_scale=None,
+ asdecimal=True,
+ ):
"""
Construct a Numeric.
@@ -587,6 +609,7 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
def literal_processor(self, dialect):
def process(value):
return str(value)
+
return process
@property
@@ -608,19 +631,23 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
# we're a "numeric", DBAPI will give us Decimal directly
return None
else:
- util.warn('Dialect %s+%s does *not* support Decimal '
- 'objects natively, and SQLAlchemy must '
- 'convert from floating point - rounding '
- 'errors and other issues may occur. Please '
- 'consider storing Decimal numbers as strings '
- 'or integers on this platform for lossless '
- 'storage.' % (dialect.name, dialect.driver))
+ util.warn(
+ "Dialect %s+%s does *not* support Decimal "
+ "objects natively, and SQLAlchemy must "
+ "convert from floating point - rounding "
+ "errors and other issues may occur. Please "
+ "consider storing Decimal numbers as strings "
+ "or integers on this platform for lossless "
+ "storage." % (dialect.name, dialect.driver)
+ )
# we're a "numeric", DBAPI returns floats, convert.
return processors.to_decimal_processor_factory(
decimal.Decimal,
- self.scale if self.scale is not None
- else self._default_decimal_return_scale)
+ self.scale
+ if self.scale is not None
+ else self._default_decimal_return_scale,
+ )
else:
if dialect.supports_native_decimal:
return processors.to_float
@@ -635,22 +662,13 @@ class Numeric(_LookupExpressionAdapter, TypeEngine):
Numeric: self.__class__,
Integer: self.__class__,
},
- operators.div: {
- Numeric: self.__class__,
- Integer: self.__class__,
- },
+ operators.div: {Numeric: self.__class__, Integer: self.__class__},
operators.truediv: {
Numeric: self.__class__,
Integer: self.__class__,
},
- operators.add: {
- Numeric: self.__class__,
- Integer: self.__class__,
- },
- operators.sub: {
- Numeric: self.__class__,
- Integer: self.__class__,
- }
+ operators.add: {Numeric: self.__class__, Integer: self.__class__},
+ operators.sub: {Numeric: self.__class__, Integer: self.__class__},
}
@@ -675,12 +693,17 @@ class Float(Numeric):
"""
- __visit_name__ = 'float'
+ __visit_name__ = "float"
scale = None
- def __init__(self, precision=None, asdecimal=False,
- decimal_return_scale=None, **kwargs):
+ def __init__(
+ self,
+ precision=None,
+ asdecimal=False,
+ decimal_return_scale=None,
+ **kwargs
+ ):
r"""
Construct a Float.
@@ -713,14 +736,15 @@ class Float(Numeric):
self.asdecimal = asdecimal
self.decimal_return_scale = decimal_return_scale
if kwargs:
- util.warn_deprecated("Additional keyword arguments "
- "passed to Float ignored.")
+ util.warn_deprecated(
+ "Additional keyword arguments " "passed to Float ignored."
+ )
def result_processor(self, dialect, coltype):
if self.asdecimal:
return processors.to_decimal_processor_factory(
- decimal.Decimal,
- self._effective_decimal_return_scale)
+ decimal.Decimal, self._effective_decimal_return_scale
+ )
elif dialect.supports_native_decimal:
return processors.to_float
else:
@@ -746,7 +770,7 @@ class DateTime(_LookupExpressionAdapter, TypeEngine):
"""
- __visit_name__ = 'datetime'
+ __visit_name__ = "datetime"
def __init__(self, timezone=False):
"""Construct a new :class:`.DateTime`.
@@ -777,13 +801,8 @@ class DateTime(_LookupExpressionAdapter, TypeEngine):
# static/functions-datetime.html.
return {
- operators.add: {
- Interval: self.__class__,
- },
- operators.sub: {
- Interval: self.__class__,
- DateTime: Interval,
- },
+ operators.add: {Interval: self.__class__},
+ operators.sub: {Interval: self.__class__, DateTime: Interval},
}
@@ -791,7 +810,7 @@ class Date(_LookupExpressionAdapter, TypeEngine):
"""A type for ``datetime.date()`` objects."""
- __visit_name__ = 'date'
+ __visit_name__ = "date"
def get_dbapi_type(self, dbapi):
return dbapi.DATETIME
@@ -814,12 +833,9 @@ class Date(_LookupExpressionAdapter, TypeEngine):
operators.sub: {
# date - integer = date
Integer: self.__class__,
-
# date - date = integer.
Date: Integer,
-
Interval: DateTime,
-
# date - datetime = interval,
# this one is not in the PG docs
# but works
@@ -832,7 +848,7 @@ class Time(_LookupExpressionAdapter, TypeEngine):
"""A type for ``datetime.time()`` objects."""
- __visit_name__ = 'time'
+ __visit_name__ = "time"
def __init__(self, timezone=False):
self.timezone = timezone
@@ -850,14 +866,8 @@ class Time(_LookupExpressionAdapter, TypeEngine):
# static/functions-datetime.html.
return {
- operators.add: {
- Date: DateTime,
- Interval: self.__class__
- },
- operators.sub: {
- Time: Interval,
- Interval: self.__class__,
- },
+ operators.add: {Date: DateTime, Interval: self.__class__},
+ operators.sub: {Time: Interval, Interval: self.__class__},
}
@@ -872,6 +882,7 @@ class _Binary(TypeEngine):
def process(value):
value = value.decode(dialect.encoding).replace("'", "''")
return "'%s'" % value
+
return process
@property
@@ -891,14 +902,17 @@ class _Binary(TypeEngine):
return DBAPIBinary(value)
else:
return None
+
return process
# Python 3 has native bytes() type
# both sqlite3 and pg8000 seem to return it,
# psycopg2 as of 2.5 returns 'memoryview'
if util.py2k:
+
def result_processor(self, dialect, coltype):
if util.jython:
+
def process(value):
if value is not None:
if isinstance(value, array.array):
@@ -906,15 +920,19 @@ class _Binary(TypeEngine):
return str(value)
else:
return None
+
else:
process = processors.to_str
return process
+
else:
+
def result_processor(self, dialect, coltype):
def process(value):
if value is not None:
value = bytes(value)
return value
+
return process
def coerce_compared_value(self, op, value):
@@ -939,7 +957,7 @@ class LargeBinary(_Binary):
"""
- __visit_name__ = 'large_binary'
+ __visit_name__ = "large_binary"
def __init__(self, length=None):
"""
@@ -958,8 +976,9 @@ class Binary(LargeBinary):
"""Deprecated. Renamed to LargeBinary."""
def __init__(self, *arg, **kw):
- util.warn_deprecated('The Binary type has been renamed to '
- 'LargeBinary.')
+ util.warn_deprecated(
+ "The Binary type has been renamed to " "LargeBinary."
+ )
LargeBinary.__init__(self, *arg, **kw)
@@ -986,8 +1005,15 @@ class SchemaType(SchemaEventTarget):
"""
- def __init__(self, name=None, schema=None, metadata=None,
- inherit_schema=False, quote=None, _create_events=True):
+ def __init__(
+ self,
+ name=None,
+ schema=None,
+ metadata=None,
+ inherit_schema=False,
+ quote=None,
+ _create_events=True,
+ ):
if name is not None:
self.name = quoted_name(name, quote)
else:
@@ -1001,12 +1027,12 @@ class SchemaType(SchemaEventTarget):
event.listen(
self.metadata,
"before_create",
- util.portable_instancemethod(self._on_metadata_create)
+ util.portable_instancemethod(self._on_metadata_create),
)
event.listen(
self.metadata,
"after_drop",
- util.portable_instancemethod(self._on_metadata_drop)
+ util.portable_instancemethod(self._on_metadata_drop),
)
def _translate_schema(self, effective_schema, map_):
@@ -1018,7 +1044,7 @@ class SchemaType(SchemaEventTarget):
def _variant_mapping_for_set_table(self, column):
if isinstance(column.type, Variant):
variant_mapping = column.type.mapping.copy()
- variant_mapping['_default'] = column.type.impl
+ variant_mapping["_default"] = column.type.impl
else:
variant_mapping = None
return variant_mapping
@@ -1036,15 +1062,15 @@ class SchemaType(SchemaEventTarget):
table,
"before_create",
util.portable_instancemethod(
- self._on_table_create,
- {"variant_mapping": variant_mapping})
+ self._on_table_create, {"variant_mapping": variant_mapping}
+ ),
)
event.listen(
table,
"after_drop",
util.portable_instancemethod(
- self._on_table_drop,
- {"variant_mapping": variant_mapping})
+ self._on_table_drop, {"variant_mapping": variant_mapping}
+ ),
)
if self.metadata is None:
# TODO: what's the difference between self.metadata
@@ -1054,29 +1080,33 @@ class SchemaType(SchemaEventTarget):
"before_create",
util.portable_instancemethod(
self._on_metadata_create,
- {"variant_mapping": variant_mapping})
+ {"variant_mapping": variant_mapping},
+ ),
)
event.listen(
table.metadata,
"after_drop",
util.portable_instancemethod(
self._on_metadata_drop,
- {"variant_mapping": variant_mapping})
+ {"variant_mapping": variant_mapping},
+ ),
)
def copy(self, **kw):
return self.adapt(self.__class__, _create_events=True)
def adapt(self, impltype, **kw):
- schema = kw.pop('schema', self.schema)
- metadata = kw.pop('metadata', self.metadata)
- _create_events = kw.pop('_create_events', False)
- return impltype(name=self.name,
- schema=schema,
- inherit_schema=self.inherit_schema,
- metadata=metadata,
- _create_events=_create_events,
- **kw)
+ schema = kw.pop("schema", self.schema)
+ metadata = kw.pop("metadata", self.metadata)
+ _create_events = kw.pop("_create_events", False)
+ return impltype(
+ name=self.name,
+ schema=schema,
+ inherit_schema=self.inherit_schema,
+ metadata=metadata,
+ _create_events=_create_events,
+ **kw
+ )
@property
def bind(self):
@@ -1133,15 +1163,17 @@ class SchemaType(SchemaEventTarget):
t._on_metadata_drop(target, bind, **kw)
def _is_impl_for_variant(self, dialect, kw):
- variant_mapping = kw.pop('variant_mapping', None)
+ variant_mapping = kw.pop("variant_mapping", None)
if variant_mapping is None:
return True
- if dialect.name in variant_mapping and \
- variant_mapping[dialect.name] is self:
+ if (
+ dialect.name in variant_mapping
+ and variant_mapping[dialect.name] is self
+ ):
return True
elif dialect.name not in variant_mapping:
- return variant_mapping['_default'] is self
+ return variant_mapping["_default"] is self
class Enum(Emulated, String, SchemaType):
@@ -1220,7 +1252,8 @@ class Enum(Emulated, String, SchemaType):
:class:`.mysql.ENUM` - MySQL-specific type
"""
- __visit_name__ = 'enum'
+
+ __visit_name__ = "enum"
def __init__(self, *enums, **kw):
r"""Construct an enum.
@@ -1322,15 +1355,15 @@ class Enum(Emulated, String, SchemaType):
other arguments in kw to pass through.
"""
- self.native_enum = kw.pop('native_enum', True)
- self.create_constraint = kw.pop('create_constraint', True)
- self.values_callable = kw.pop('values_callable', None)
+ self.native_enum = kw.pop("native_enum", True)
+ self.create_constraint = kw.pop("create_constraint", True)
+ self.values_callable = kw.pop("values_callable", None)
values, objects = self._parse_into_values(enums, kw)
self._setup_for_values(values, objects, kw)
- convert_unicode = kw.pop('convert_unicode', None)
- self.validate_strings = kw.pop('validate_strings', False)
+ convert_unicode = kw.pop("convert_unicode", None)
+ self.validate_strings = kw.pop("validate_strings", False)
if convert_unicode is None:
for e in self.enums:
@@ -1347,33 +1380,35 @@ class Enum(Emulated, String, SchemaType):
self._valid_lookup[None] = self._object_lookup[None] = None
super(Enum, self).__init__(
- length=length,
- convert_unicode=convert_unicode,
+ length=length, convert_unicode=convert_unicode
)
if self.enum_class:
- kw.setdefault('name', self.enum_class.__name__.lower())
+ kw.setdefault("name", self.enum_class.__name__.lower())
SchemaType.__init__(
self,
- name=kw.pop('name', None),
- schema=kw.pop('schema', None),
- metadata=kw.pop('metadata', None),
- inherit_schema=kw.pop('inherit_schema', False),
- quote=kw.pop('quote', None),
- _create_events=kw.pop('_create_events', True)
+ name=kw.pop("name", None),
+ schema=kw.pop("schema", None),
+ metadata=kw.pop("metadata", None),
+ inherit_schema=kw.pop("inherit_schema", False),
+ quote=kw.pop("quote", None),
+ _create_events=kw.pop("_create_events", True),
)
def _parse_into_values(self, enums, kw):
- if not enums and '_enums' in kw:
- enums = kw.pop('_enums')
+ if not enums and "_enums" in kw:
+ enums = kw.pop("_enums")
- if len(enums) == 1 and hasattr(enums[0], '__members__'):
+ if len(enums) == 1 and hasattr(enums[0], "__members__"):
self.enum_class = enums[0]
if self.values_callable:
values = self.values_callable(self.enum_class)
else:
values = list(self.enum_class.__members__)
- objects = [self.enum_class.__members__[k] for k in self.enum_class.__members__]
+ objects = [
+ self.enum_class.__members__[k]
+ for k in self.enum_class.__members__
+ ]
return values, objects
else:
self.enum_class = None
@@ -1382,18 +1417,16 @@ class Enum(Emulated, String, SchemaType):
def _setup_for_values(self, values, objects, kw):
self.enums = list(values)
- self._valid_lookup = dict(
- zip(reversed(objects), reversed(values))
- )
+ self._valid_lookup = dict(zip(reversed(objects), reversed(values)))
- self._object_lookup = dict(
- zip(values, objects)
- )
+ self._object_lookup = dict(zip(values, objects))
- self._valid_lookup.update([
- (value, self._valid_lookup[self._object_lookup[value]])
- for value in values
- ])
+ self._valid_lookup.update(
+ [
+ (value, self._valid_lookup[self._object_lookup[value]])
+ for value in values
+ ]
+ )
@property
def native(self):
@@ -1411,22 +1444,24 @@ class Enum(Emulated, String, SchemaType):
# here between an INSERT statement and a criteria used in a SELECT,
# for now we're staying conservative w/ behavioral changes (perhaps
# someone has a trigger that handles strings on INSERT)
- if not self.validate_strings and \
- isinstance(elem, compat.string_types):
+ if not self.validate_strings and isinstance(
+ elem, compat.string_types
+ ):
return elem
else:
raise LookupError(
- '"%s" is not among the defined enum values' % elem)
+ '"%s" is not among the defined enum values' % elem
+ )
class Comparator(String.Comparator):
-
def _adapt_expression(self, op, other_comparator):
op, typ = super(Enum.Comparator, self)._adapt_expression(
- op, other_comparator)
+ op, other_comparator
+ )
if op is operators.concat_op:
typ = String(
- self.type.length,
- convert_unicode=self.type.convert_unicode)
+ self.type.length, convert_unicode=self.type.convert_unicode
+ )
return op, typ
comparator_factory = Comparator
@@ -1436,38 +1471,40 @@ class Enum(Emulated, String, SchemaType):
return self._object_lookup[elem]
except KeyError:
raise LookupError(
- '"%s" is not among the defined enum values' % elem)
+ '"%s" is not among the defined enum values' % elem
+ )
def __repr__(self):
return util.generic_repr(
self,
- additional_kw=[('native_enum', True)],
+ additional_kw=[("native_enum", True)],
to_inspect=[Enum, SchemaType],
)
def adapt_to_emulated(self, impltype, **kw):
kw.setdefault("convert_unicode", self.convert_unicode)
kw.setdefault("validate_strings", self.validate_strings)
- kw.setdefault('name', self.name)
- kw.setdefault('schema', self.schema)
- kw.setdefault('inherit_schema', self.inherit_schema)
- kw.setdefault('metadata', self.metadata)
- kw.setdefault('_create_events', False)
- kw.setdefault('native_enum', self.native_enum)
- kw.setdefault('values_callable', self.values_callable)
- kw.setdefault('create_constraint', self.create_constraint)
- assert '_enums' in kw
+ kw.setdefault("name", self.name)
+ kw.setdefault("schema", self.schema)
+ kw.setdefault("inherit_schema", self.inherit_schema)
+ kw.setdefault("metadata", self.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("native_enum", self.native_enum)
+ kw.setdefault("values_callable", self.values_callable)
+ kw.setdefault("create_constraint", self.create_constraint)
+ assert "_enums" in kw
return impltype(**kw)
def adapt(self, impltype, **kw):
- kw['_enums'] = self._enums_argument
+ kw["_enums"] = self._enums_argument
return super(Enum, self).adapt(impltype, **kw)
def _should_create_constraint(self, compiler, **kw):
if not self._is_impl_for_variant(compiler.dialect, kw):
return False
- return not self.native_enum or \
- not compiler.dialect.supports_native_enum
+ return (
+ not self.native_enum or not compiler.dialect.supports_native_enum
+ )
@util.dependencies("sqlalchemy.sql.schema")
def _set_table(self, schema, column, table):
@@ -1483,20 +1520,21 @@ class Enum(Emulated, String, SchemaType):
name=_defer_name(self.name),
_create_rule=util.portable_instancemethod(
self._should_create_constraint,
- {"variant_mapping": variant_mapping}),
- _type_bound=True
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
)
assert e.table is table
def literal_processor(self, dialect):
- parent_processor = super(
- Enum, self).literal_processor(dialect)
+ parent_processor = super(Enum, self).literal_processor(dialect)
def process(value):
value = self._db_value_for_elem(value)
if parent_processor:
value = parent_processor(value)
return value
+
return process
def bind_processor(self, dialect):
@@ -1510,8 +1548,7 @@ class Enum(Emulated, String, SchemaType):
return process
def result_processor(self, dialect, coltype):
- parent_processor = super(Enum, self).result_processor(
- dialect, coltype)
+ parent_processor = super(Enum, self).result_processor(dialect, coltype)
def process(value):
if parent_processor:
@@ -1548,8 +1585,9 @@ class PickleType(TypeDecorator):
impl = LargeBinary
- def __init__(self, protocol=pickle.HIGHEST_PROTOCOL,
- pickler=None, comparator=None):
+ def __init__(
+ self, protocol=pickle.HIGHEST_PROTOCOL, pickler=None, comparator=None
+ ):
"""
Construct a PickleType.
@@ -1570,40 +1608,46 @@ class PickleType(TypeDecorator):
super(PickleType, self).__init__()
def __reduce__(self):
- return PickleType, (self.protocol,
- None,
- self.comparator)
+ return PickleType, (self.protocol, None, self.comparator)
def bind_processor(self, dialect):
impl_processor = self.impl.bind_processor(dialect)
dumps = self.pickler.dumps
protocol = self.protocol
if impl_processor:
+
def process(value):
if value is not None:
value = dumps(value, protocol)
return impl_processor(value)
+
else:
+
def process(value):
if value is not None:
value = dumps(value, protocol)
return value
+
return process
def result_processor(self, dialect, coltype):
impl_processor = self.impl.result_processor(dialect, coltype)
loads = self.pickler.loads
if impl_processor:
+
def process(value):
value = impl_processor(value)
if value is None:
return None
return loads(value)
+
else:
+
def process(value):
if value is None:
return None
return loads(value)
+
return process
def compare_values(self, x, y):
@@ -1635,11 +1679,10 @@ class Boolean(Emulated, TypeEngine, SchemaType):
"""
- __visit_name__ = 'boolean'
+ __visit_name__ = "boolean"
native = True
- def __init__(
- self, create_constraint=True, name=None, _create_events=True):
+ def __init__(self, create_constraint=True, name=None, _create_events=True):
"""Construct a Boolean.
:param create_constraint: defaults to True. If the boolean
@@ -1657,8 +1700,10 @@ class Boolean(Emulated, TypeEngine, SchemaType):
def _should_create_constraint(self, compiler, **kw):
if not self._is_impl_for_variant(compiler.dialect, kw):
return False
- return not compiler.dialect.supports_native_boolean and \
- compiler.dialect.non_native_boolean_check_constraint
+ return (
+ not compiler.dialect.supports_native_boolean
+ and compiler.dialect.non_native_boolean_check_constraint
+ )
@util.dependencies("sqlalchemy.sql.schema")
def _set_table(self, schema, column, table):
@@ -1672,8 +1717,9 @@ class Boolean(Emulated, TypeEngine, SchemaType):
name=_defer_name(self.name),
_create_rule=util.portable_instancemethod(
self._should_create_constraint,
- {"variant_mapping": variant_mapping}),
- _type_bound=True
+ {"variant_mapping": variant_mapping},
+ ),
+ _type_bound=True,
)
assert e.table is table
@@ -1686,11 +1732,11 @@ class Boolean(Emulated, TypeEngine, SchemaType):
def _strict_as_bool(self, value):
if value not in self._strict_bools:
if not isinstance(value, int):
- raise TypeError(
- "Not a boolean value: %r" % value)
+ raise TypeError("Not a boolean value: %r" % value)
else:
raise ValueError(
- "Value %r is not None, True, or False" % value)
+ "Value %r is not None, True, or False" % value
+ )
return value
def literal_processor(self, dialect):
@@ -1700,6 +1746,7 @@ class Boolean(Emulated, TypeEngine, SchemaType):
def process(value):
return true if self._strict_as_bool(value) else false
+
return process
def bind_processor(self, dialect):
@@ -1714,6 +1761,7 @@ class Boolean(Emulated, TypeEngine, SchemaType):
if value is not None:
value = _coerce(value)
return value
+
return process
def result_processor(self, dialect, coltype):
@@ -1736,18 +1784,10 @@ class _AbstractInterval(_LookupExpressionAdapter, TypeEngine):
DateTime: DateTime,
Time: Time,
},
- operators.sub: {
- Interval: self.__class__
- },
- operators.mul: {
- Numeric: self.__class__
- },
- operators.truediv: {
- Numeric: self.__class__
- },
- operators.div: {
- Numeric: self.__class__
- }
+ operators.sub: {Interval: self.__class__},
+ operators.mul: {Numeric: self.__class__},
+ operators.truediv: {Numeric: self.__class__},
+ operators.div: {Numeric: self.__class__},
}
@property
@@ -1780,9 +1820,7 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
impl = DateTime
epoch = dt.datetime.utcfromtimestamp(0)
- def __init__(self, native=True,
- second_precision=None,
- day_precision=None):
+ def __init__(self, native=True, second_precision=None, day_precision=None):
"""Construct an Interval object.
:param native: when True, use the actual
@@ -1815,31 +1853,39 @@ class Interval(Emulated, _AbstractInterval, TypeDecorator):
impl_processor = self.impl.bind_processor(dialect)
epoch = self.epoch
if impl_processor:
+
def process(value):
if value is not None:
value = epoch + value
return impl_processor(value)
+
else:
+
def process(value):
if value is not None:
value = epoch + value
return value
+
return process
def result_processor(self, dialect, coltype):
impl_processor = self.impl.result_processor(dialect, coltype)
epoch = self.epoch
if impl_processor:
+
def process(value):
value = impl_processor(value)
if value is None:
return None
return value - epoch
+
else:
+
def process(value):
if value is None:
return None
return value - epoch
+
return process
@@ -1986,10 +2032,11 @@ class JSON(Indexable, TypeEngine):
"""
- __visit_name__ = 'JSON'
+
+ __visit_name__ = "JSON"
hashable = False
- NULL = util.symbol('JSON_NULL')
+ NULL = util.symbol("JSON_NULL")
"""Describe the json value of NULL.
This value is used to force the JSON value of ``"null"`` to be
@@ -2109,20 +2156,25 @@ class JSON(Indexable, TypeEngine):
class Comparator(Indexable.Comparator, Concatenable.Comparator):
"""Define comparison operations for :class:`.types.JSON`."""
- @util.dependencies('sqlalchemy.sql.default_comparator')
+ @util.dependencies("sqlalchemy.sql.default_comparator")
def _setup_getitem(self, default_comparator, index):
- if not isinstance(index, util.string_types) and \
- isinstance(index, compat.collections_abc.Sequence):
+ if not isinstance(index, util.string_types) and isinstance(
+ index, compat.collections_abc.Sequence
+ ):
index = default_comparator._check_literal(
- self.expr, operators.json_path_getitem_op,
- index, bindparam_type=JSON.JSONPathType
+ self.expr,
+ operators.json_path_getitem_op,
+ index,
+ bindparam_type=JSON.JSONPathType,
)
operator = operators.json_path_getitem_op
else:
index = default_comparator._check_literal(
- self.expr, operators.json_getitem_op,
- index, bindparam_type=JSON.JSONIndexType
+ self.expr,
+ operators.json_getitem_op,
+ index,
+ bindparam_type=JSON.JSONIndexType,
)
operator = operators.json_getitem_op
@@ -2172,6 +2224,7 @@ class JSON(Indexable, TypeEngine):
if string_process:
value = string_process(value)
return json_deserializer(value)
+
return process
@@ -2266,7 +2319,8 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
:class:`.postgresql.ARRAY`
"""
- __visit_name__ = 'ARRAY'
+
+ __visit_name__ = "ARRAY"
zero_indexes = False
"""if True, Python zero-based indexes should be interpreted as one-based
@@ -2285,21 +2339,23 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
if isinstance(index, slice):
return_type = self.type
if self.type.zero_indexes:
- index = slice(
- index.start + 1,
- index.stop + 1,
- index.step
- )
+ index = slice(index.start + 1, index.stop + 1, index.step)
index = Slice(
_literal_as_binds(
- index.start, name=self.expr.key,
- type_=type_api.INTEGERTYPE),
+ index.start,
+ name=self.expr.key,
+ type_=type_api.INTEGERTYPE,
+ ),
_literal_as_binds(
- index.stop, name=self.expr.key,
- type_=type_api.INTEGERTYPE),
+ index.stop,
+ name=self.expr.key,
+ type_=type_api.INTEGERTYPE,
+ ),
_literal_as_binds(
- index.step, name=self.expr.key,
- type_=type_api.INTEGERTYPE)
+ index.step,
+ name=self.expr.key,
+ type_=type_api.INTEGERTYPE,
+ ),
)
else:
if self.type.zero_indexes:
@@ -2307,16 +2363,18 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
if self.type.dimensions is None or self.type.dimensions == 1:
return_type = self.type.item_type
else:
- adapt_kw = {'dimensions': self.type.dimensions - 1}
+ adapt_kw = {"dimensions": self.type.dimensions - 1}
return_type = self.type.adapt(
- self.type.__class__, **adapt_kw)
+ self.type.__class__, **adapt_kw
+ )
return operators.getitem, index, return_type
def contains(self, *arg, **kw):
raise NotImplementedError(
"ARRAY.contains() not implemented for the base "
- "ARRAY type; please use the dialect-specific ARRAY type")
+ "ARRAY type; please use the dialect-specific ARRAY type"
+ )
@util.dependencies("sqlalchemy.sql.elements")
def any(self, elements, other, operator=None):
@@ -2350,7 +2408,7 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
operator = operator if operator else operators.eq
return operator(
elements._literal_as_binds(other),
- elements.CollectionAggregate._create_any(self.expr)
+ elements.CollectionAggregate._create_any(self.expr),
)
@util.dependencies("sqlalchemy.sql.elements")
@@ -2385,13 +2443,14 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
operator = operator if operator else operators.eq
return operator(
elements._literal_as_binds(other),
- elements.CollectionAggregate._create_all(self.expr)
+ elements.CollectionAggregate._create_all(self.expr),
)
comparator_factory = Comparator
- def __init__(self, item_type, as_tuple=False, dimensions=None,
- zero_indexes=False):
+ def __init__(
+ self, item_type, as_tuple=False, dimensions=None, zero_indexes=False
+ ):
"""Construct an :class:`.types.ARRAY`.
E.g.::
@@ -2424,8 +2483,10 @@ class ARRAY(SchemaEventTarget, Indexable, Concatenable, TypeEngine):
"""
if isinstance(item_type, ARRAY):
- raise ValueError("Do not nest ARRAY types; ARRAY(basetype) "
- "handles multi-dimensional arrays of basetype")
+ raise ValueError(
+ "Do not nest ARRAY types; ARRAY(basetype) "
+ "handles multi-dimensional arrays of basetype"
+ )
if isinstance(item_type, type):
item_type = item_type()
self.item_type = item_type
@@ -2463,35 +2524,37 @@ class REAL(Float):
"""The SQL REAL type."""
- __visit_name__ = 'REAL'
+ __visit_name__ = "REAL"
class FLOAT(Float):
"""The SQL FLOAT type."""
- __visit_name__ = 'FLOAT'
+ __visit_name__ = "FLOAT"
class NUMERIC(Numeric):
"""The SQL NUMERIC type."""
- __visit_name__ = 'NUMERIC'
+ __visit_name__ = "NUMERIC"
class DECIMAL(Numeric):
"""The SQL DECIMAL type."""
- __visit_name__ = 'DECIMAL'
+ __visit_name__ = "DECIMAL"
class INTEGER(Integer):
"""The SQL INT or INTEGER type."""
- __visit_name__ = 'INTEGER'
+ __visit_name__ = "INTEGER"
+
+
INT = INTEGER
@@ -2499,14 +2562,14 @@ class SMALLINT(SmallInteger):
"""The SQL SMALLINT type."""
- __visit_name__ = 'SMALLINT'
+ __visit_name__ = "SMALLINT"
class BIGINT(BigInteger):
"""The SQL BIGINT type."""
- __visit_name__ = 'BIGINT'
+ __visit_name__ = "BIGINT"
class TIMESTAMP(DateTime):
@@ -2520,7 +2583,7 @@ class TIMESTAMP(DateTime):
"""
- __visit_name__ = 'TIMESTAMP'
+ __visit_name__ = "TIMESTAMP"
def __init__(self, timezone=False):
"""Construct a new :class:`.TIMESTAMP`.
@@ -2543,28 +2606,28 @@ class DATETIME(DateTime):
"""The SQL DATETIME type."""
- __visit_name__ = 'DATETIME'
+ __visit_name__ = "DATETIME"
class DATE(Date):
"""The SQL DATE type."""
- __visit_name__ = 'DATE'
+ __visit_name__ = "DATE"
class TIME(Time):
"""The SQL TIME type."""
- __visit_name__ = 'TIME'
+ __visit_name__ = "TIME"
class TEXT(Text):
"""The SQL TEXT type."""
- __visit_name__ = 'TEXT'
+ __visit_name__ = "TEXT"
class CLOB(Text):
@@ -2574,63 +2637,63 @@ class CLOB(Text):
This type is found in Oracle and Informix.
"""
- __visit_name__ = 'CLOB'
+ __visit_name__ = "CLOB"
class VARCHAR(String):
"""The SQL VARCHAR type."""
- __visit_name__ = 'VARCHAR'
+ __visit_name__ = "VARCHAR"
class NVARCHAR(Unicode):
"""The SQL NVARCHAR type."""
- __visit_name__ = 'NVARCHAR'
+ __visit_name__ = "NVARCHAR"
class CHAR(String):
"""The SQL CHAR type."""
- __visit_name__ = 'CHAR'
+ __visit_name__ = "CHAR"
class NCHAR(Unicode):
"""The SQL NCHAR type."""
- __visit_name__ = 'NCHAR'
+ __visit_name__ = "NCHAR"
class BLOB(LargeBinary):
"""The SQL BLOB type."""
- __visit_name__ = 'BLOB'
+ __visit_name__ = "BLOB"
class BINARY(_Binary):
"""The SQL BINARY type."""
- __visit_name__ = 'BINARY'
+ __visit_name__ = "BINARY"
class VARBINARY(_Binary):
"""The SQL VARBINARY type."""
- __visit_name__ = 'VARBINARY'
+ __visit_name__ = "VARBINARY"
class BOOLEAN(Boolean):
"""The SQL BOOLEAN type."""
- __visit_name__ = 'BOOLEAN'
+ __visit_name__ = "BOOLEAN"
class NullType(TypeEngine):
@@ -2657,7 +2720,8 @@ class NullType(TypeEngine):
construct.
"""
- __visit_name__ = 'null'
+
+ __visit_name__ = "null"
_isnull = True
@@ -2666,16 +2730,18 @@ class NullType(TypeEngine):
def literal_processor(self, dialect):
def process(value):
return "NULL"
+
return process
class Comparator(TypeEngine.Comparator):
-
def _adapt_expression(self, op, other_comparator):
- if isinstance(other_comparator, NullType.Comparator) or \
- not operators.is_commutative(op):
+ if isinstance(
+ other_comparator, NullType.Comparator
+ ) or not operators.is_commutative(op):
return op, self.expr.type
else:
return other_comparator._adapt_expression(op, self)
+
comparator_factory = Comparator
@@ -2694,6 +2760,7 @@ class MatchType(Boolean):
"""
+
NULLTYPE = NullType()
BOOLEANTYPE = Boolean()
STRINGTYPE = String()
@@ -2709,7 +2776,7 @@ _type_map = {
dt.datetime: DateTime(),
dt.time: Time(),
dt.timedelta: Interval(),
- util.NoneType: NULLTYPE
+ util.NoneType: NULLTYPE,
}
if util.py3k:
@@ -2729,19 +2796,23 @@ def _resolve_value_to_type(value):
# objects.
insp = inspection.inspect(value, False)
if (
- insp is not None and
- # foil mock.Mock() and other impostors by ensuring
- # the inspection target itself self-inspects
- insp.__class__ in inspection._registrars
+ insp is not None
+ and
+ # foil mock.Mock() and other impostors by ensuring
+ # the inspection target itself self-inspects
+ insp.__class__ in inspection._registrars
):
raise exc.ArgumentError(
- "Object %r is not legal as a SQL literal value" % value)
+ "Object %r is not legal as a SQL literal value" % value
+ )
return NULLTYPE
else:
return _result_type
+
# back-assign to type_api
from . import type_api
+
type_api.BOOLEANTYPE = BOOLEANTYPE
type_api.STRINGTYPE = STRINGTYPE
type_api.INTEGERTYPE = INTEGERTYPE
diff --git a/lib/sqlalchemy/sql/type_api.py b/lib/sqlalchemy/sql/type_api.py
index a8dfa19be..7fe780783 100644
--- a/lib/sqlalchemy/sql/type_api.py
+++ b/lib/sqlalchemy/sql/type_api.py
@@ -49,7 +49,8 @@ class TypeEngine(Visitable):
"""
- __slots__ = 'expr', 'type'
+
+ __slots__ = "expr", "type"
default_comparator = None
@@ -57,16 +58,15 @@ class TypeEngine(Visitable):
self.expr = expr
self.type = expr.type
- @util.dependencies('sqlalchemy.sql.default_comparator')
+ @util.dependencies("sqlalchemy.sql.default_comparator")
def operate(self, default_comparator, op, *other, **kwargs):
o = default_comparator.operator_lookup[op.__name__]
return o[0](self.expr, op, *(other + o[1:]), **kwargs)
- @util.dependencies('sqlalchemy.sql.default_comparator')
+ @util.dependencies("sqlalchemy.sql.default_comparator")
def reverse_operate(self, default_comparator, op, other, **kwargs):
o = default_comparator.operator_lookup[op.__name__]
- return o[0](self.expr, op, other,
- reverse=True, *o[1:], **kwargs)
+ return o[0](self.expr, op, other, reverse=True, *o[1:], **kwargs)
def _adapt_expression(self, op, other_comparator):
"""evaluate the return type of <self> <op> <othertype>,
@@ -97,7 +97,7 @@ class TypeEngine(Visitable):
return op, self.type
def __reduce__(self):
- return _reconstitute_comparator, (self.expr, )
+ return _reconstitute_comparator, (self.expr,)
hashable = True
"""Flag, if False, means values from this type aren't hashable.
@@ -313,8 +313,10 @@ class TypeEngine(Visitable):
"""
- return self.__class__.column_expression.__code__ \
+ return (
+ self.__class__.column_expression.__code__
is not TypeEngine.column_expression.__code__
+ )
def bind_expression(self, bindvalue):
""""Given a bind value (i.e. a :class:`.BindParameter` instance),
@@ -351,8 +353,10 @@ class TypeEngine(Visitable):
"""
- return self.__class__.bind_expression.__code__ \
+ return (
+ self.__class__.bind_expression.__code__
is not TypeEngine.bind_expression.__code__
+ )
@staticmethod
def _to_instance(cls_or_self):
@@ -441,9 +445,9 @@ class TypeEngine(Visitable):
"""
try:
- return dialect._type_memos[self]['impl']
+ return dialect._type_memos[self]["impl"]
except KeyError:
- return self._dialect_info(dialect)['impl']
+ return self._dialect_info(dialect)["impl"]
def _unwrapped_dialect_impl(self, dialect):
"""Return the 'unwrapped' dialect impl for this type.
@@ -462,20 +466,20 @@ class TypeEngine(Visitable):
def _cached_literal_processor(self, dialect):
"""Return a dialect-specific literal processor for this type."""
try:
- return dialect._type_memos[self]['literal']
+ return dialect._type_memos[self]["literal"]
except KeyError:
d = self._dialect_info(dialect)
- d['literal'] = lp = d['impl'].literal_processor(dialect)
+ d["literal"] = lp = d["impl"].literal_processor(dialect)
return lp
def _cached_bind_processor(self, dialect):
"""Return a dialect-specific bind processor for this type."""
try:
- return dialect._type_memos[self]['bind']
+ return dialect._type_memos[self]["bind"]
except KeyError:
d = self._dialect_info(dialect)
- d['bind'] = bp = d['impl'].bind_processor(dialect)
+ d["bind"] = bp = d["impl"].bind_processor(dialect)
return bp
def _cached_result_processor(self, dialect, coltype):
@@ -488,7 +492,7 @@ class TypeEngine(Visitable):
# key assumption: DBAPI type codes are
# constants. Else this dictionary would
# grow unbounded.
- d[coltype] = rp = d['impl'].result_processor(dialect, coltype)
+ d[coltype] = rp = d["impl"].result_processor(dialect, coltype)
return rp
def _cached_custom_processor(self, dialect, key, fn):
@@ -496,7 +500,7 @@ class TypeEngine(Visitable):
return dialect._type_memos[self][key]
except KeyError:
d = self._dialect_info(dialect)
- impl = d['impl']
+ impl = d["impl"]
d[key] = result = fn(impl)
return result
@@ -513,7 +517,7 @@ class TypeEngine(Visitable):
impl = self.adapt(type(self))
# this can't be self, else we create a cycle
assert impl is not self
- dialect._type_memos[self] = d = {'impl': impl}
+ dialect._type_memos[self] = d = {"impl": impl}
return d
def _gen_dialect_impl(self, dialect):
@@ -549,8 +553,10 @@ class TypeEngine(Visitable):
"""
_coerced_type = _resolve_value_to_type(value)
- if _coerced_type is NULLTYPE or _coerced_type._type_affinity \
- is self._type_affinity:
+ if (
+ _coerced_type is NULLTYPE
+ or _coerced_type._type_affinity is self._type_affinity
+ ):
return self
else:
return _coerced_type
@@ -586,8 +592,7 @@ class TypeEngine(Visitable):
def __str__(self):
if util.py2k:
- return unicode(self.compile()).\
- encode('ascii', 'backslashreplace')
+ return unicode(self.compile()).encode("ascii", "backslashreplace")
else:
return str(self.compile())
@@ -645,15 +650,16 @@ class UserDefinedType(util.with_metaclass(VisitableCheckKWArg, TypeEngine)):
``type_expression``, if it receives ``**kw`` in its signature.
"""
+
__visit_name__ = "user_defined"
- ensure_kwarg = 'get_col_spec'
+ ensure_kwarg = "get_col_spec"
class Comparator(TypeEngine.Comparator):
__slots__ = ()
def _adapt_expression(self, op, other_comparator):
- if hasattr(self.type, 'adapt_operator'):
+ if hasattr(self.type, "adapt_operator"):
util.warn_deprecated(
"UserDefinedType.adapt_operator is deprecated. Create "
"a UserDefinedType.Comparator subclass instead which "
@@ -854,6 +860,7 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
will cause the index value ``'foo'`` to be JSON encoded.
"""
+
__visit_name__ = "type_decorator"
def __init__(self, *args, **kwargs):
@@ -874,14 +881,16 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
- if not hasattr(self.__class__, 'impl'):
- raise AssertionError("TypeDecorator implementations "
- "require a class-level variable "
- "'impl' which refers to the class of "
- "type being decorated")
+ if not hasattr(self.__class__, "impl"):
+ raise AssertionError(
+ "TypeDecorator implementations "
+ "require a class-level variable "
+ "'impl' which refers to the class of "
+ "type being decorated"
+ )
self.impl = to_instance(self.__class__.impl, *args, **kwargs)
- coerce_to_is_types = (util.NoneType, )
+ coerce_to_is_types = (util.NoneType,)
"""Specify those Python types which should be coerced at the expression
level to "IS <constant>" when compared using ``==`` (and same for
``IS NOT`` in conjunction with ``!=``.
@@ -906,24 +915,27 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
__slots__ = ()
def operate(self, op, *other, **kwargs):
- kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
return super(TypeDecorator.Comparator, self).operate(
- op, *other, **kwargs)
+ op, *other, **kwargs
+ )
def reverse_operate(self, op, other, **kwargs):
- kwargs['_python_is_types'] = self.expr.type.coerce_to_is_types
+ kwargs["_python_is_types"] = self.expr.type.coerce_to_is_types
return super(TypeDecorator.Comparator, self).reverse_operate(
- op, other, **kwargs)
+ op, other, **kwargs
+ )
@property
def comparator_factory(self):
if TypeDecorator.Comparator in self.impl.comparator_factory.__mro__:
return self.impl.comparator_factory
else:
- return type("TDComparator",
- (TypeDecorator.Comparator,
- self.impl.comparator_factory),
- {})
+ return type(
+ "TDComparator",
+ (TypeDecorator.Comparator, self.impl.comparator_factory),
+ {},
+ )
def _gen_dialect_impl(self, dialect):
"""
@@ -939,10 +951,11 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
typedesc = self._unwrapped_dialect_impl(dialect)
tt = self.copy()
if not isinstance(tt, self.__class__):
- raise AssertionError('Type object %s does not properly '
- 'implement the copy() method, it must '
- 'return an object of type %s' %
- (self, self.__class__))
+ raise AssertionError(
+ "Type object %s does not properly "
+ "implement the copy() method, it must "
+ "return an object of type %s" % (self, self.__class__)
+ )
tt.impl = typedesc
return tt
@@ -1099,8 +1112,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
- return self.__class__.process_bind_param.__code__ \
+ return (
+ self.__class__.process_bind_param.__code__
is not TypeDecorator.process_bind_param.__code__
+ )
@util.memoized_property
def _has_literal_processor(self):
@@ -1109,8 +1124,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
- return self.__class__.process_literal_param.__code__ \
+ return (
+ self.__class__.process_literal_param.__code__
is not TypeDecorator.process_literal_param.__code__
+ )
def literal_processor(self, dialect):
"""Provide a literal processing function for the given
@@ -1147,9 +1164,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
if process_param:
impl_processor = self.impl.literal_processor(dialect)
if impl_processor:
+
def process(value):
return impl_processor(process_param(value, dialect))
+
else:
+
def process(value):
return process_param(value, dialect)
@@ -1180,10 +1200,12 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
process_param = self.process_bind_param
impl_processor = self.impl.bind_processor(dialect)
if impl_processor:
+
def process(value):
return impl_processor(process_param(value, dialect))
else:
+
def process(value):
return process_param(value, dialect)
@@ -1200,8 +1222,10 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
exception throw.
"""
- return self.__class__.process_result_value.__code__ \
+ return (
+ self.__class__.process_result_value.__code__
is not TypeDecorator.process_result_value.__code__
+ )
def result_processor(self, dialect, coltype):
"""Provide a result value processing function for the given
@@ -1225,13 +1249,14 @@ class TypeDecorator(SchemaEventTarget, TypeEngine):
"""
if self._has_result_processor:
process_value = self.process_result_value
- impl_processor = self.impl.result_processor(dialect,
- coltype)
+ impl_processor = self.impl.result_processor(dialect, coltype)
if impl_processor:
+
def process(value):
return process_value(impl_processor(value), dialect)
else:
+
def process(value):
return process_value(value, dialect)
@@ -1397,7 +1422,8 @@ class Variant(TypeDecorator):
if dialect_name in self.mapping:
raise exc.ArgumentError(
"Dialect '%s' is already present in "
- "the mapping for this Variant" % dialect_name)
+ "the mapping for this Variant" % dialect_name
+ )
mapping = self.mapping.copy()
mapping[dialect_name] = type_
return Variant(self.impl, mapping)
@@ -1439,6 +1465,6 @@ def adapt_type(typeobj, colspecs):
# but it turns out the originally given "generic" type
# is actually a subclass of our resulting type, then we were already
# given a more specific type than that required; so use that.
- if (issubclass(typeobj.__class__, impltype)):
+ if issubclass(typeobj.__class__, impltype):
return typeobj
return typeobj.adapt(impltype)
diff --git a/lib/sqlalchemy/sql/util.py b/lib/sqlalchemy/sql/util.py
index 12cfe09d1..4feaf9938 100644
--- a/lib/sqlalchemy/sql/util.py
+++ b/lib/sqlalchemy/sql/util.py
@@ -15,15 +15,29 @@ from . import operators, visitors
from itertools import chain
from collections import deque
-from .elements import BindParameter, ColumnClause, ColumnElement, \
- Null, UnaryExpression, literal_column, Label, _label_reference, \
- _textual_label_reference
-from .selectable import SelectBase, ScalarSelect, Join, FromClause, FromGrouping
+from .elements import (
+ BindParameter,
+ ColumnClause,
+ ColumnElement,
+ Null,
+ UnaryExpression,
+ literal_column,
+ Label,
+ _label_reference,
+ _textual_label_reference,
+)
+from .selectable import (
+ SelectBase,
+ ScalarSelect,
+ Join,
+ FromClause,
+ FromGrouping,
+)
from .schema import Column
join_condition = util.langhelpers.public_factory(
- Join._join_condition,
- ".sql.util.join_condition")
+ Join._join_condition, ".sql.util.join_condition"
+)
# names that are still being imported from the outside
from .annotation import _shallow_annotate, _deep_annotate, _deep_deannotate
@@ -88,8 +102,9 @@ def find_left_clause_that_matches_given(clauses, join_from):
for idx in liberal_idx:
f = clauses[idx]
for s in selectables:
- if set(surface_selectables(f)).\
- intersection(surface_selectables(s)):
+ if set(surface_selectables(f)).intersection(
+ surface_selectables(s)
+ ):
conservative_idx.append(idx)
break
if conservative_idx:
@@ -184,8 +199,9 @@ def visit_binary_product(fn, expr):
# we don't want to dig into correlated subqueries,
# those are just column elements by themselves
yield element
- elif element.__visit_name__ == 'binary' and \
- operators.is_comparison(element.operator):
+ elif element.__visit_name__ == "binary" and operators.is_comparison(
+ element.operator
+ ):
stack.insert(0, element)
for l in visit(element.left):
for r in visit(element.right):
@@ -199,38 +215,47 @@ def visit_binary_product(fn, expr):
for elem in element.get_children():
for e in visit(elem):
yield e
+
list(visit(expr))
-def find_tables(clause, check_columns=False,
- include_aliases=False, include_joins=False,
- include_selects=False, include_crud=False):
+def find_tables(
+ clause,
+ check_columns=False,
+ include_aliases=False,
+ include_joins=False,
+ include_selects=False,
+ include_crud=False,
+):
"""locate Table objects within the given expression."""
tables = []
_visitors = {}
if include_selects:
- _visitors['select'] = _visitors['compound_select'] = tables.append
+ _visitors["select"] = _visitors["compound_select"] = tables.append
if include_joins:
- _visitors['join'] = tables.append
+ _visitors["join"] = tables.append
if include_aliases:
- _visitors['alias'] = tables.append
+ _visitors["alias"] = tables.append
if include_crud:
- _visitors['insert'] = _visitors['update'] = \
- _visitors['delete'] = lambda ent: tables.append(ent.table)
+ _visitors["insert"] = _visitors["update"] = _visitors[
+ "delete"
+ ] = lambda ent: tables.append(ent.table)
if check_columns:
+
def visit_column(column):
tables.append(column.table)
- _visitors['column'] = visit_column
- _visitors['table'] = tables.append
+ _visitors["column"] = visit_column
- visitors.traverse(clause, {'column_collections': False}, _visitors)
+ _visitors["table"] = tables.append
+
+ visitors.traverse(clause, {"column_collections": False}, _visitors)
return tables
@@ -243,10 +268,9 @@ def unwrap_order_by(clause):
stack = deque([clause])
while stack:
t = stack.popleft()
- if isinstance(t, ColumnElement) and \
- (
- not isinstance(t, UnaryExpression) or
- not operators.is_ordering_modifier(t.modifier)
+ if isinstance(t, ColumnElement) and (
+ not isinstance(t, UnaryExpression)
+ or not operators.is_ordering_modifier(t.modifier)
):
if isinstance(t, _label_reference):
t = t.element
@@ -266,9 +290,7 @@ def unwrap_label_reference(element):
if isinstance(elem, (_label_reference, _textual_label_reference)):
return elem.element
- return visitors.replacement_traverse(
- element, {}, replace
- )
+ return visitors.replacement_traverse(element, {}, replace)
def expand_column_list_from_order_by(collist, order_by):
@@ -278,17 +300,16 @@ def expand_column_list_from_order_by(collist, order_by):
in the collist.
"""
- cols_already_present = set([
- col.element if col._order_by_label_element is not None
- else col for col in collist
- ])
+ cols_already_present = set(
+ [
+ col.element if col._order_by_label_element is not None else col
+ for col in collist
+ ]
+ )
return [
- col for col in
- chain(*[
- unwrap_order_by(o)
- for o in order_by
- ])
+ col
+ for col in chain(*[unwrap_order_by(o) for o in order_by])
if col not in cols_already_present
]
@@ -325,9 +346,9 @@ def surface_column_elements(clause, include_scalar_selects=True):
be addressable in the WHERE clause of a SELECT if this element were
in the columns clause."""
- filter_ = (FromGrouping, )
+ filter_ = (FromGrouping,)
if not include_scalar_selects:
- filter_ += (SelectBase, )
+ filter_ += (SelectBase,)
stack = deque([clause])
while stack:
@@ -343,9 +364,7 @@ def selectables_overlap(left, right):
"""Return True if left/right have some overlapping selectable"""
return bool(
- set(surface_selectables(left)).intersection(
- surface_selectables(right)
- )
+ set(surface_selectables(left)).intersection(surface_selectables(right))
)
@@ -366,7 +385,7 @@ def bind_values(clause):
def visit_bindparam(bind):
v.append(bind.effective_value)
- visitors.traverse(clause, {}, {'bindparam': visit_bindparam})
+ visitors.traverse(clause, {}, {"bindparam": visit_bindparam})
return v
@@ -383,7 +402,7 @@ class _repr_base(object):
_TUPLE = 1
_DICT = 2
- __slots__ = 'max_chars',
+ __slots__ = ("max_chars",)
def trunc(self, value):
rep = repr(value)
@@ -391,10 +410,12 @@ class _repr_base(object):
if lenrep > self.max_chars:
segment_length = self.max_chars // 2
rep = (
- rep[0:segment_length] +
- (" ... (%d characters truncated) ... "
- % (lenrep - self.max_chars)) +
- rep[-segment_length:]
+ rep[0:segment_length]
+ + (
+ " ... (%d characters truncated) ... "
+ % (lenrep - self.max_chars)
+ )
+ + rep[-segment_length:]
)
return rep
@@ -402,7 +423,7 @@ class _repr_base(object):
class _repr_row(_repr_base):
"""Provide a string view of a row."""
- __slots__ = 'row',
+ __slots__ = ("row",)
def __init__(self, row, max_chars=300):
self.row = row
@@ -412,7 +433,7 @@ class _repr_row(_repr_base):
trunc = self.trunc
return "(%s%s)" % (
", ".join(trunc(value) for value in self.row),
- "," if len(self.row) == 1 else ""
+ "," if len(self.row) == 1 else "",
)
@@ -424,7 +445,7 @@ class _repr_params(_repr_base):
"""
- __slots__ = 'params', 'batches',
+ __slots__ = "params", "batches"
def __init__(self, params, batches, max_chars=300):
self.params = params
@@ -435,11 +456,13 @@ class _repr_params(_repr_base):
if isinstance(self.params, list):
typ = self._LIST
ismulti = self.params and isinstance(
- self.params[0], (list, dict, tuple))
+ self.params[0], (list, dict, tuple)
+ )
elif isinstance(self.params, tuple):
typ = self._TUPLE
ismulti = self.params and isinstance(
- self.params[0], (list, dict, tuple))
+ self.params[0], (list, dict, tuple)
+ )
elif isinstance(self.params, dict):
typ = self._DICT
ismulti = False
@@ -448,11 +471,15 @@ class _repr_params(_repr_base):
if ismulti and len(self.params) > self.batches:
msg = " ... displaying %i of %i total bound parameter sets ... "
- return ' '.join((
- self._repr_multi(self.params[:self.batches - 2], typ)[0:-1],
- msg % (self.batches, len(self.params)),
- self._repr_multi(self.params[-2:], typ)[1:]
- ))
+ return " ".join(
+ (
+ self._repr_multi(self.params[: self.batches - 2], typ)[
+ 0:-1
+ ],
+ msg % (self.batches, len(self.params)),
+ self._repr_multi(self.params[-2:], typ)[1:],
+ )
+ )
elif ismulti:
return self._repr_multi(self.params, typ)
else:
@@ -467,12 +494,13 @@ class _repr_params(_repr_base):
elif isinstance(multi_params[0], dict):
elem_type = self._DICT
else:
- assert False, \
- "Unknown parameter type %s" % (type(multi_params[0]))
+ assert False, "Unknown parameter type %s" % (
+ type(multi_params[0])
+ )
elements = ", ".join(
- self._repr_params(params, elem_type)
- for params in multi_params)
+ self._repr_params(params, elem_type) for params in multi_params
+ )
else:
elements = ""
@@ -493,13 +521,10 @@ class _repr_params(_repr_base):
elif typ is self._TUPLE:
return "(%s%s)" % (
", ".join(trunc(value) for value in params),
- "," if len(params) == 1 else ""
-
+ "," if len(params) == 1 else "",
)
else:
- return "[%s]" % (
- ", ".join(trunc(value) for value in params)
- )
+ return "[%s]" % (", ".join(trunc(value) for value in params))
def adapt_criterion_to_null(crit, nulls):
@@ -509,20 +534,24 @@ def adapt_criterion_to_null(crit, nulls):
"""
def visit_binary(binary):
- if isinstance(binary.left, BindParameter) \
- and binary.left._identifying_key in nulls:
+ if (
+ isinstance(binary.left, BindParameter)
+ and binary.left._identifying_key in nulls
+ ):
# reverse order if the NULL is on the left side
binary.left = binary.right
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
- elif isinstance(binary.right, BindParameter) \
- and binary.right._identifying_key in nulls:
+ elif (
+ isinstance(binary.right, BindParameter)
+ and binary.right._identifying_key in nulls
+ ):
binary.right = Null()
binary.operator = operators.is_
binary.negate = operators.isnot
- return visitors.cloned_traverse(crit, {}, {'binary': visit_binary})
+ return visitors.cloned_traverse(crit, {}, {"binary": visit_binary})
def splice_joins(left, right, stop_on=None):
@@ -570,8 +599,8 @@ def reduce_columns(columns, *clauses, **kw):
in the selectable to just those that are not repeated.
"""
- ignore_nonexistent_tables = kw.pop('ignore_nonexistent_tables', False)
- only_synonyms = kw.pop('only_synonyms', False)
+ ignore_nonexistent_tables = kw.pop("ignore_nonexistent_tables", False)
+ only_synonyms = kw.pop("only_synonyms", False)
columns = util.ordered_column_set(columns)
@@ -597,39 +626,48 @@ def reduce_columns(columns, *clauses, **kw):
continue
else:
raise
- if fk_col.shares_lineage(c) and \
- (not only_synonyms or
- c.name == col.name):
+ if fk_col.shares_lineage(c) and (
+ not only_synonyms or c.name == col.name
+ ):
omit.add(col)
break
if clauses:
+
def visit_binary(binary):
if binary.operator == operators.eq:
cols = util.column_set(
- chain(*[c.proxy_set for c in columns.difference(omit)]))
+ chain(*[c.proxy_set for c in columns.difference(omit)])
+ )
if binary.left in cols and binary.right in cols:
for c in reversed(columns):
- if c.shares_lineage(binary.right) and \
- (not only_synonyms or
- c.name == binary.left.name):
+ if c.shares_lineage(binary.right) and (
+ not only_synonyms or c.name == binary.left.name
+ ):
omit.add(c)
break
+
for clause in clauses:
if clause is not None:
- visitors.traverse(clause, {}, {'binary': visit_binary})
+ visitors.traverse(clause, {}, {"binary": visit_binary})
return ColumnSet(columns.difference(omit))
-def criterion_as_pairs(expression, consider_as_foreign_keys=None,
- consider_as_referenced_keys=None, any_operator=False):
+def criterion_as_pairs(
+ expression,
+ consider_as_foreign_keys=None,
+ consider_as_referenced_keys=None,
+ any_operator=False,
+):
"""traverse an expression and locate binary criterion pairs."""
if consider_as_foreign_keys and consider_as_referenced_keys:
- raise exc.ArgumentError("Can only specify one of "
- "'consider_as_foreign_keys' or "
- "'consider_as_referenced_keys'")
+ raise exc.ArgumentError(
+ "Can only specify one of "
+ "'consider_as_foreign_keys' or "
+ "'consider_as_referenced_keys'"
+ )
def col_is(a, b):
# return a is b
@@ -638,37 +676,44 @@ def criterion_as_pairs(expression, consider_as_foreign_keys=None,
def visit_binary(binary):
if not any_operator and binary.operator is not operators.eq:
return
- if not isinstance(binary.left, ColumnElement) or \
- not isinstance(binary.right, ColumnElement):
+ if not isinstance(binary.left, ColumnElement) or not isinstance(
+ binary.right, ColumnElement
+ ):
return
if consider_as_foreign_keys:
- if binary.left in consider_as_foreign_keys and \
- (col_is(binary.right, binary.left) or
- binary.right not in consider_as_foreign_keys):
+ if binary.left in consider_as_foreign_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_foreign_keys
+ ):
pairs.append((binary.right, binary.left))
- elif binary.right in consider_as_foreign_keys and \
- (col_is(binary.left, binary.right) or
- binary.left not in consider_as_foreign_keys):
+ elif binary.right in consider_as_foreign_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_foreign_keys
+ ):
pairs.append((binary.left, binary.right))
elif consider_as_referenced_keys:
- if binary.left in consider_as_referenced_keys and \
- (col_is(binary.right, binary.left) or
- binary.right not in consider_as_referenced_keys):
+ if binary.left in consider_as_referenced_keys and (
+ col_is(binary.right, binary.left)
+ or binary.right not in consider_as_referenced_keys
+ ):
pairs.append((binary.left, binary.right))
- elif binary.right in consider_as_referenced_keys and \
- (col_is(binary.left, binary.right) or
- binary.left not in consider_as_referenced_keys):
+ elif binary.right in consider_as_referenced_keys and (
+ col_is(binary.left, binary.right)
+ or binary.left not in consider_as_referenced_keys
+ ):
pairs.append((binary.right, binary.left))
else:
- if isinstance(binary.left, Column) and \
- isinstance(binary.right, Column):
+ if isinstance(binary.left, Column) and isinstance(
+ binary.right, Column
+ ):
if binary.left.references(binary.right):
pairs.append((binary.right, binary.left))
elif binary.right.references(binary.left):
pairs.append((binary.left, binary.right))
+
pairs = []
- visitors.traverse(expression, {}, {'binary': visit_binary})
+ visitors.traverse(expression, {}, {"binary": visit_binary})
return pairs
@@ -699,28 +744,38 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
"""
- def __init__(self, selectable, equivalents=None,
- include_fn=None, exclude_fn=None,
- adapt_on_names=False, anonymize_labels=False):
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ anonymize_labels=False,
+ ):
self.__traverse_options__ = {
- 'stop_on': [selectable],
- 'anonymize_labels': anonymize_labels}
+ "stop_on": [selectable],
+ "anonymize_labels": anonymize_labels,
+ }
self.selectable = selectable
self.include_fn = include_fn
self.exclude_fn = exclude_fn
self.equivalents = util.column_dict(equivalents or {})
self.adapt_on_names = adapt_on_names
- def _corresponding_column(self, col, require_embedded,
- _seen=util.EMPTY_SET):
+ def _corresponding_column(
+ self, col, require_embedded, _seen=util.EMPTY_SET
+ ):
newcol = self.selectable.corresponding_column(
- col,
- require_embedded=require_embedded)
+ col, require_embedded=require_embedded
+ )
if newcol is None and col in self.equivalents and col not in _seen:
for equiv in self.equivalents[col]:
newcol = self._corresponding_column(
- equiv, require_embedded=require_embedded,
- _seen=_seen.union([col]))
+ equiv,
+ require_embedded=require_embedded,
+ _seen=_seen.union([col]),
+ )
if newcol is not None:
return newcol
if self.adapt_on_names and newcol is None:
@@ -728,8 +783,9 @@ class ClauseAdapter(visitors.ReplacingCloningVisitor):
return newcol
def replace(self, col):
- if isinstance(col, FromClause) and \
- self.selectable.is_derived_from(col):
+ if isinstance(col, FromClause) and self.selectable.is_derived_from(
+ col
+ ):
return self.selectable
elif not isinstance(col, ColumnElement):
return None
@@ -772,16 +828,27 @@ class ColumnAdapter(ClauseAdapter):
"""
- def __init__(self, selectable, equivalents=None,
- chain_to=None, adapt_required=False,
- include_fn=None, exclude_fn=None,
- adapt_on_names=False,
- allow_label_resolve=True,
- anonymize_labels=False):
- ClauseAdapter.__init__(self, selectable, equivalents,
- include_fn=include_fn, exclude_fn=exclude_fn,
- adapt_on_names=adapt_on_names,
- anonymize_labels=anonymize_labels)
+ def __init__(
+ self,
+ selectable,
+ equivalents=None,
+ chain_to=None,
+ adapt_required=False,
+ include_fn=None,
+ exclude_fn=None,
+ adapt_on_names=False,
+ allow_label_resolve=True,
+ anonymize_labels=False,
+ ):
+ ClauseAdapter.__init__(
+ self,
+ selectable,
+ equivalents,
+ include_fn=include_fn,
+ exclude_fn=exclude_fn,
+ adapt_on_names=adapt_on_names,
+ anonymize_labels=anonymize_labels,
+ )
if chain_to:
self.chain(chain_to)
@@ -800,9 +867,7 @@ class ColumnAdapter(ClauseAdapter):
def __getitem__(self, key):
if (
self.parent.include_fn and not self.parent.include_fn(key)
- ) or (
- self.parent.exclude_fn and self.parent.exclude_fn(key)
- ):
+ ) or (self.parent.exclude_fn and self.parent.exclude_fn(key)):
if self.parent._wrap:
return self.parent._wrap.columns[key]
else:
@@ -843,7 +908,7 @@ class ColumnAdapter(ClauseAdapter):
def __getstate__(self):
d = self.__dict__.copy()
- del d['columns']
+ del d["columns"]
return d
def __setstate__(self, state):
diff --git a/lib/sqlalchemy/sql/visitors.py b/lib/sqlalchemy/sql/visitors.py
index b39ec8167..bf1743643 100644
--- a/lib/sqlalchemy/sql/visitors.py
+++ b/lib/sqlalchemy/sql/visitors.py
@@ -29,11 +29,20 @@ from .. import util
import operator
from .. import exc
-__all__ = ['VisitableType', 'Visitable', 'ClauseVisitor',
- 'CloningVisitor', 'ReplacingCloningVisitor', 'iterate',
- 'iterate_depthfirst', 'traverse_using', 'traverse',
- 'traverse_depthfirst',
- 'cloned_traverse', 'replacement_traverse']
+__all__ = [
+ "VisitableType",
+ "Visitable",
+ "ClauseVisitor",
+ "CloningVisitor",
+ "ReplacingCloningVisitor",
+ "iterate",
+ "iterate_depthfirst",
+ "traverse_using",
+ "traverse",
+ "traverse_depthfirst",
+ "cloned_traverse",
+ "replacement_traverse",
+]
class VisitableType(type):
@@ -53,8 +62,7 @@ class VisitableType(type):
"""
def __init__(cls, clsname, bases, clsdict):
- if clsname != 'Visitable' and \
- hasattr(cls, '__visit_name__'):
+ if clsname != "Visitable" and hasattr(cls, "__visit_name__"):
_generate_dispatch(cls)
super(VisitableType, cls).__init__(clsname, bases, clsdict)
@@ -64,7 +72,7 @@ def _generate_dispatch(cls):
"""Return an optimized visit dispatch function for the cls
for use by the compiler.
"""
- if '__visit_name__' in cls.__dict__:
+ if "__visit_name__" in cls.__dict__:
visit_name = cls.__visit_name__
if isinstance(visit_name, str):
# There is an optimization opportunity here because the
@@ -79,12 +87,13 @@ def _generate_dispatch(cls):
raise exc.UnsupportedCompilationError(visitor, cls)
else:
return meth(self, **kw)
+
else:
# The optimization opportunity is lost for this case because the
# __visit_name__ is not yet a string. As a result, the visit
# string has to be recalculated with each compilation.
def _compiler_dispatch(self, visitor, **kw):
- visit_attr = 'visit_%s' % self.__visit_name__
+ visit_attr = "visit_%s" % self.__visit_name__
try:
meth = getattr(visitor, visit_attr)
except AttributeError:
@@ -92,8 +101,7 @@ def _generate_dispatch(cls):
else:
return meth(self, **kw)
- _compiler_dispatch.__doc__ = \
- """Look for an attribute named "visit_" + self.__visit_name__
+ _compiler_dispatch.__doc__ = """Look for an attribute named "visit_" + self.__visit_name__
on the visitor, and call it with the same kw params.
"""
cls._compiler_dispatch = _compiler_dispatch
@@ -137,7 +145,7 @@ class ClauseVisitor(object):
visitors = {}
for name in dir(self):
- if name.startswith('visit_'):
+ if name.startswith("visit_"):
visitors[name[6:]] = getattr(self, name)
return visitors
@@ -148,7 +156,7 @@ class ClauseVisitor(object):
v = self
while v:
yield v
- v = getattr(v, '_next', None)
+ v = getattr(v, "_next", None)
def chain(self, visitor):
"""'chain' an additional ClauseVisitor onto this ClauseVisitor.
@@ -178,7 +186,8 @@ class CloningVisitor(ClauseVisitor):
"""traverse and visit the given expression structure."""
return cloned_traverse(
- obj, self.__traverse_options__, self._visitor_dict)
+ obj, self.__traverse_options__, self._visitor_dict
+ )
class ReplacingCloningVisitor(CloningVisitor):
@@ -204,6 +213,7 @@ class ReplacingCloningVisitor(CloningVisitor):
e = v.replace(elem)
if e is not None:
return e
+
return replacement_traverse(obj, self.__traverse_options__, replace)
@@ -282,7 +292,7 @@ def cloned_traverse(obj, opts, visitors):
modifications by visitors."""
cloned = {}
- stop_on = set(opts.get('stop_on', []))
+ stop_on = set(opts.get("stop_on", []))
def clone(elem):
if elem in stop_on:
@@ -306,11 +316,13 @@ def replacement_traverse(obj, opts, replace):
replacement by a given replacement function."""
cloned = {}
- stop_on = {id(x) for x in opts.get('stop_on', [])}
+ stop_on = {id(x) for x in opts.get("stop_on", [])}
def clone(elem, **kw):
- if id(elem) in stop_on or \
- 'no_replacement_traverse' in elem._annotations:
+ if (
+ id(elem) in stop_on
+ or "no_replacement_traverse" in elem._annotations
+ ):
return elem
else:
newelem = replace(elem)