summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql')
-rw-r--r--lib/sqlalchemy/sql/compiler.py74
-rw-r--r--lib/sqlalchemy/sql/elements.py87
-rw-r--r--lib/sqlalchemy/sql/expression.py1
-rw-r--r--lib/sqlalchemy/sql/schema.py8
-rw-r--r--lib/sqlalchemy/sql/selectable.py32
5 files changed, 150 insertions, 52 deletions
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b703c59f2..15ddd7d6f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -139,8 +139,16 @@ RESERVED_WORDS = set(
)
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I)
ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+FK_ON_DELETE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_ON_UPDATE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
@@ -758,12 +766,11 @@ class SQLCompiler(Compiled):
else:
col = with_cols[element.element]
except KeyError:
- # treat it like text()
- util.warn_limited(
- "Can't resolve label reference %r; converting to text()",
- util.ellipses_string(element.element),
+ elements._no_text_coercion(
+ element.element,
+ exc.CompileError,
+ "Can't resolve label reference for ORDER BY / GROUP BY.",
)
- return self.process(element._text_clause)
else:
kwargs["render_label_as_label"] = col
return self.process(
@@ -1076,10 +1083,24 @@ class SQLCompiler(Compiled):
if func._has_args:
name += "%(expr)s"
else:
- name = func.name + "%(expr)s"
- return ".".join(list(func.packagenames) + [name]) % {
- "expr": self.function_argspec(func, **kwargs)
- }
+ name = func.name
+ name = (
+ self.preparer.quote(name)
+ if self.preparer._requires_quotes_illegal_chars(name)
+ else name
+ )
+ name = name + "%(expr)s"
+ return ".".join(
+ [
+ (
+ self.preparer.quote(tok)
+ if self.preparer._requires_quotes_illegal_chars(tok)
+ else tok
+ )
+ for tok in func.packagenames
+ ]
+ + [name]
+ ) % {"expr": self.function_argspec(func, **kwargs)}
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
@@ -3153,9 +3174,13 @@ class DDLCompiler(Compiled):
def define_constraint_cascades(self, constraint):
text = ""
if constraint.ondelete is not None:
- text += " ON DELETE %s" % constraint.ondelete
+ text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
+ constraint.ondelete, FK_ON_DELETE
+ )
if constraint.onupdate is not None:
- text += " ON UPDATE %s" % constraint.onupdate
+ text += " ON UPDATE %s" % self.preparer.validate_sql_phrase(
+ constraint.onupdate, FK_ON_UPDATE
+ )
return text
def define_constraint_deferrability(self, constraint):
@@ -3166,7 +3191,9 @@ class DDLCompiler(Compiled):
else:
text += " NOT DEFERRABLE"
if constraint.initially is not None:
- text += " INITIALLY %s" % constraint.initially
+ text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
+ constraint.initially, FK_INITIALLY
+ )
return text
def define_constraint_match(self, constraint):
@@ -3416,6 +3443,24 @@ class IdentifierPreparer(object):
return value.replace(self.escape_to_quote, self.escape_quote)
+ def validate_sql_phrase(self, element, reg):
+ """keyword sequence filter.
+
+ a filter for elements that are intended to represent keyword sequences,
+ such as "INITIALLY", "INTIALLY DEFERRED", etc. no special characters
+ should be present.
+
+ .. versionadded:: 1.3
+
+ """
+
+ if element is not None and not reg.match(element):
+ raise exc.CompileError(
+ "Unexpected SQL phrase: %r (matching against %r)"
+ % (element, reg.pattern)
+ )
+ return element
+
def quote_identifier(self, value):
"""Quote an identifier.
@@ -3439,6 +3484,11 @@ class IdentifierPreparer(object):
or (lc_value != value)
)
+ def _requires_quotes_illegal_chars(self, value):
+ """Return True if the given identifier requires quoting, but
+ not taking case convention into account."""
+ return not self.legal_characters.match(util.text_type(value))
+
def quote_schema(self, schema, force=None):
"""Conditionally quote a schema name.
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 9e4f5d95d..a4623128f 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -37,6 +37,20 @@ def _clone(element, **kw):
return element._clone()
+def _document_text_coercion(paramname, meth_rst, param_rst):
+ return util.add_parameter_text(
+ paramname,
+ (
+ ".. warning:: "
+ "The %s argument to %s can be passed as a Python string argument, "
+ "which will be treated "
+ "as **trusted SQL text** and rendered as given. **DO NOT PASS "
+ "UNTRUSTED INPUT TO THIS PARAMETER**."
+ )
+ % (param_rst, meth_rst),
+ )
+
+
def collate(expression, collation):
"""Return the clause ``expression COLLATE collation``.
@@ -1343,6 +1357,7 @@ class TextClause(Executable, ClauseElement):
"refer to the :meth:`.TextClause.columns` method.",
),
)
+ @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`")
def _create_text(
self, text, bind=None, bindparams=None, typemap=None, autocommit=None
):
@@ -4430,32 +4445,64 @@ def _literal_and_labels_as_label_reference(element):
def _expression_literal_as_text(element):
- return _literal_as_text(element, warn=True)
+ return _literal_as_text(element)
-def _literal_as_text(element, warn=False):
+def _literal_as(element, text_fallback):
if isinstance(element, Visitable):
return element
elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif isinstance(element, util.string_types):
- if warn:
- util.warn_limited(
- "Textual SQL expression %(expr)r should be "
- "explicitly declared as text(%(expr)r)",
- {"expr": util.ellipses_string(element)},
- )
-
- return TextClause(util.text_type(element))
+ return text_fallback(element)
elif isinstance(element, (util.NoneType, bool)):
return _const_expr(element)
else:
raise exc.ArgumentError(
- "SQL expression object or string expected, got object of type %r "
+ "SQL expression object expected, got object of type %r "
"instead" % type(element)
)
+def _literal_as_text(element, allow_coercion_to_text=False):
+ if allow_coercion_to_text:
+ return _literal_as(element, TextClause)
+ else:
+ return _literal_as(element, _no_text_coercion)
+
+
+def _literal_as_column(element):
+ return _literal_as(element, ColumnClause)
+
+
+def _no_column_coercion(element):
+ element = str(element)
+ guess_is_literal = not _guess_straight_column.match(element)
+ raise exc.ArgumentError(
+ "Textual column expression %(column)r should be "
+ "explicitly declared with text(%(column)r), "
+ "or use %(literal_column)s(%(column)r) "
+ "for more specificity"
+ % {
+ "column": util.ellipses_string(element),
+ "literal_column": "literal_column"
+ if guess_is_literal
+ else "column",
+ }
+ )
+
+
+def _no_text_coercion(element, exc_cls=exc.ArgumentError, extra=None):
+ raise exc_cls(
+ "%(extra)sTextual SQL expression %(expr)r should be "
+ "explicitly declared as text(%(expr)r)"
+ % {
+ "expr": util.ellipses_string(element),
+ "extra": "%s " % extra if extra else "",
+ }
+ )
+
+
def _no_literals(element):
if hasattr(element, "__clause_element__"):
return element.__clause_element__()
@@ -4529,23 +4576,7 @@ def _interpret_as_column_or_from(element):
elif isinstance(element, (numbers.Number)):
return ColumnClause(str(element), is_literal=True)
else:
- element = str(element)
- # give into temptation, as this fact we are guessing about
- # is not one we've previously ever needed our users tell us;
- # but let them know we are not happy about it
- guess_is_literal = not _guess_straight_column.match(element)
- util.warn_limited(
- "Textual column expression %(column)r should be "
- "explicitly declared with text(%(column)r), "
- "or use %(literal_column)s(%(column)r) "
- "for more specificity",
- {
- "column": util.ellipses_string(element),
- "literal_column": "literal_column"
- if guess_is_literal
- else "column",
- },
- )
+ _no_column_coercion(element)
return ColumnClause(element, is_literal=guess_is_literal)
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 2a27d0b73..82fe93029 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -101,6 +101,7 @@ from .elements import _expression_literal_as_text # noqa
from .elements import _is_column # noqa
from .elements import _labeled # noqa
from .elements import _literal_as_binds # noqa
+from .elements import _literal_as_column # noqa
from .elements import _literal_as_label_reference # noqa
from .elements import _literal_as_text # noqa
from .elements import _only_column_elements # noqa
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 8997e119f..e981d7aed 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -42,6 +42,7 @@ from .base import ColumnCollection
from .base import DialectKWArgs
from .base import SchemaEventTarget
from .elements import _as_truncated
+from .elements import _document_text_coercion
from .elements import _literal_as_text
from .elements import ClauseElement
from .elements import ColumnClause
@@ -2884,6 +2885,11 @@ class CheckConstraint(ColumnCollectionConstraint):
_allow_multiple_tables = True
+ @_document_text_coercion(
+ "sqltext",
+ ":class:`.CheckConstraint`",
+ ":paramref:`.CheckConstraint.sqltext`",
+ )
def __init__(
self,
sqltext,
@@ -2925,7 +2931,7 @@ class CheckConstraint(ColumnCollectionConstraint):
"""
- self.sqltext = _literal_as_text(sqltext, warn=False)
+ self.sqltext = _literal_as_text(sqltext, allow_coercion_to_text=True)
columns = []
visitors.traverse(self.sqltext, {}, {"column": columns.append})
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index a5dee068c..ac08604f5 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -31,11 +31,13 @@ from .elements import _clause_element_as_expr
from .elements import _clone
from .elements import _cloned_difference
from .elements import _cloned_intersection
+from .elements import _document_text_coercion
from .elements import _expand_cloned
from .elements import _interpret_as_column_or_from
from .elements import _literal_and_labels_as_label_reference
from .elements import _literal_as_label_reference
from .elements import _literal_as_text
+from .elements import _no_text_coercion
from .elements import _select_iterables
from .elements import and_
from .elements import BindParameter
@@ -43,7 +45,6 @@ from .elements import ClauseElement
from .elements import ClauseList
from .elements import Grouping
from .elements import literal_column
-from .elements import TextClause
from .elements import True_
from .elements import UnaryExpression
from .. import exc
@@ -55,14 +56,7 @@ def _interpret_as_from(element):
insp = inspection.inspect(element, raiseerr=False)
if insp is None:
if isinstance(element, util.string_types):
- util.warn_limited(
- "Textual SQL FROM expression %(expr)r should be "
- "explicitly declared as text(%(expr)r), "
- "or use table(%(expr)r) for more specificity",
- {"expr": util.ellipses_string(element)},
- )
-
- return TextClause(util.text_type(element))
+ _no_text_coercion(element)
try:
return insp.selectable
except AttributeError:
@@ -266,6 +260,11 @@ class HasPrefixes(object):
_prefixes = ()
@_generative
+ @_document_text_coercion(
+ "expr",
+ ":meth:`.HasPrefixes.prefix_with`",
+ ":paramref:`.HasPrefixes.prefix_with.*expr`",
+ )
def prefix_with(self, *expr, **kw):
r"""Add one or more expressions following the statement keyword, i.e.
SELECT, INSERT, UPDATE, or DELETE. Generative.
@@ -297,7 +296,10 @@ class HasPrefixes(object):
def _setup_prefixes(self, prefixes, dialect=None):
self._prefixes = self._prefixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in prefixes]
+ [
+ (_literal_as_text(p, allow_coercion_to_text=True), dialect)
+ for p in prefixes
+ ]
)
@@ -305,6 +307,11 @@ class HasSuffixes(object):
_suffixes = ()
@_generative
+ @_document_text_coercion(
+ "expr",
+ ":meth:`.HasSuffixes.suffix_with`",
+ ":paramref:`.HasSuffixes.suffix_with.*expr`",
+ )
def suffix_with(self, *expr, **kw):
r"""Add one or more expressions following the statement as a whole.
@@ -335,7 +342,10 @@ class HasSuffixes(object):
def _setup_suffixes(self, suffixes, dialect=None):
self._suffixes = self._suffixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in suffixes]
+ [
+ (_literal_as_text(p, allow_coercion_to_text=True), dialect)
+ for p in suffixes
+ ]
)