summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/__init__.py2
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py11
-rw-r--r--lib/sqlalchemy/dialects/postgresql/ext.py38
-rw-r--r--lib/sqlalchemy/orm/session.py4
-rw-r--r--lib/sqlalchemy/sql/compiler.py74
-rw-r--r--lib/sqlalchemy/sql/elements.py87
-rw-r--r--lib/sqlalchemy/sql/expression.py1
-rw-r--r--lib/sqlalchemy/sql/schema.py8
-rw-r--r--lib/sqlalchemy/sql/selectable.py32
-rw-r--r--lib/sqlalchemy/util/__init__.py1
-rw-r--r--lib/sqlalchemy/util/deprecations.py64
-rw-r--r--lib/sqlalchemy/util/langhelpers.py80
12 files changed, 263 insertions, 139 deletions
diff --git a/lib/sqlalchemy/__init__.py b/lib/sqlalchemy/__init__.py
index 5fda721fe..33a0e4af2 100644
--- a/lib/sqlalchemy/__init__.py
+++ b/lib/sqlalchemy/__init__.py
@@ -122,7 +122,7 @@ from .engine import create_engine # noqa nosort
from .engine import engine_from_config # noqa nosort
-__version__ = '1.3.0b3'
+__version__ = "1.3.0b3"
def __go(lcls):
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 4004a2b9a..4d302dabe 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -948,6 +948,8 @@ except ImportError:
_python_UUID = None
+IDX_USING = re.compile(r"^(?:btree|hash|gist|gin|[\w_]+)$", re.I)
+
AUTOCOMMIT_REGEXP = re.compile(
r"\s*(?:UPDATE|INSERT|CREATE|DELETE|DROP|ALTER|GRANT|REVOKE|"
"IMPORT FOREIGN SCHEMA|REFRESH MATERIALIZED VIEW|TRUNCATE)",
@@ -1908,7 +1910,10 @@ class PGDDLCompiler(compiler.DDLCompiler):
using = index.dialect_options["postgresql"]["using"]
if using:
- text += "USING %s " % preparer.quote(using)
+ text += (
+ "USING %s "
+ % self.preparer.validate_sql_phrase(using, IDX_USING).lower()
+ )
ops = index.dialect_options["postgresql"]["ops"]
text += "(%s)" % (
@@ -1983,7 +1988,9 @@ class PGDDLCompiler(compiler.DDLCompiler):
"%s WITH %s" % (self.sql_compiler.process(expr, **kw), op)
)
text += "EXCLUDE USING %s (%s)" % (
- constraint.using,
+ self.preparer.validate_sql_phrase(
+ constraint.using, IDX_USING
+ ).lower(),
", ".join(elements),
)
if constraint.where is not None:
diff --git a/lib/sqlalchemy/dialects/postgresql/ext.py b/lib/sqlalchemy/dialects/postgresql/ext.py
index 49b5e0ec0..426028239 100644
--- a/lib/sqlalchemy/dialects/postgresql/ext.py
+++ b/lib/sqlalchemy/dialects/postgresql/ext.py
@@ -91,6 +91,11 @@ class ExcludeConstraint(ColumnCollectionConstraint):
where = None
+ @elements._document_text_coercion(
+ "where",
+ ":class:`.ExcludeConstraint`",
+ ":paramref:`.ExcludeConstraint.where`",
+ )
def __init__(self, *elements, **kw):
r"""
Create an :class:`.ExcludeConstraint` object.
@@ -123,21 +128,15 @@ class ExcludeConstraint(ColumnCollectionConstraint):
)
:param \*elements:
+
A sequence of two tuples of the form ``(column, operator)`` where
"column" is a SQL expression element or a raw SQL string, most
- typically a :class:`.Column` object,
- and "operator" is a string containing the operator to use.
-
- .. note::
-
- A plain string passed for the value of "column" is interpreted
- as an arbitrary SQL expression; when passing a plain string,
- any necessary quoting and escaping syntaxes must be applied
- manually. In order to specify a column name when a
- :class:`.Column` object is not available, while ensuring that
- any necessary quoting rules take effect, an ad-hoc
- :class:`.Column` or :func:`.sql.expression.column` object may
- be used.
+ typically a :class:`.Column` object, and "operator" is a string
+ containing the operator to use. In order to specify a column name
+ when a :class:`.Column` object is not available, while ensuring
+ that any necessary quoting rules take effect, an ad-hoc
+ :class:`.Column` or :func:`.sql.expression.column` object should be
+ used.
:param name:
Optional, the in-database name of this constraint.
@@ -159,12 +158,6 @@ class ExcludeConstraint(ColumnCollectionConstraint):
If set, emit WHERE <predicate> when issuing DDL
for this constraint.
- .. note::
-
- A plain string passed here is interpreted as an arbitrary SQL
- expression; when passing a plain string, any necessary quoting
- and escaping syntaxes must be applied manually.
-
"""
columns = []
render_exprs = []
@@ -184,11 +177,12 @@ class ExcludeConstraint(ColumnCollectionConstraint):
# backwards compat
self.operators[name] = operator
- expr = expression._literal_as_text(expr)
+ expr = expression._literal_as_column(expr)
render_exprs.append((expr, name, operator))
self._render_exprs = render_exprs
+
ColumnCollectionConstraint.__init__(
self,
*columns,
@@ -199,7 +193,9 @@ class ExcludeConstraint(ColumnCollectionConstraint):
self.using = kw.get("using", "gist")
where = kw.get("where")
if where is not None:
- self.where = expression._literal_as_text(where)
+ self.where = expression._literal_as_text(
+ where, allow_coercion_to_text=True
+ )
def copy(self, **kw):
elements = [(col, self.operators[col]) for col in self.columns.keys()]
diff --git a/lib/sqlalchemy/orm/session.py b/lib/sqlalchemy/orm/session.py
index 9e52ef208..6d4198a4e 100644
--- a/lib/sqlalchemy/orm/session.py
+++ b/lib/sqlalchemy/orm/session.py
@@ -1257,7 +1257,9 @@ class Session(_SessionClassMethods):
in order to execute the statement.
"""
- clause = expression._literal_as_text(clause)
+ clause = expression._literal_as_text(
+ clause, allow_coercion_to_text=True
+ )
if bind is None:
bind = self.get_bind(mapper, clause=clause, **kw)
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index b703c59f2..15ddd7d6f 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -139,8 +139,16 @@ RESERVED_WORDS = set(
)
LEGAL_CHARACTERS = re.compile(r"^[A-Z0-9_$]+$", re.I)
+LEGAL_CHARACTERS_PLUS_SPACE = re.compile(r"^[A-Z0-9_ $]+$", re.I)
ILLEGAL_INITIAL_CHARACTERS = {str(x) for x in range(0, 10)}.union(["$"])
+FK_ON_DELETE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_ON_UPDATE = re.compile(
+ r"^(?:RESTRICT|CASCADE|SET NULL|NO ACTION|SET DEFAULT)$", re.I
+)
+FK_INITIALLY = re.compile(r"^(?:DEFERRED|IMMEDIATE)$", re.I)
BIND_PARAMS = re.compile(r"(?<![:\w\$\x5c]):([\w\$]+)(?![:\w\$])", re.UNICODE)
BIND_PARAMS_ESC = re.compile(r"\x5c(:[\w\$]*)(?![:\w\$])", re.UNICODE)
@@ -758,12 +766,11 @@ class SQLCompiler(Compiled):
else:
col = with_cols[element.element]
except KeyError:
- # treat it like text()
- util.warn_limited(
- "Can't resolve label reference %r; converting to text()",
- util.ellipses_string(element.element),
+ elements._no_text_coercion(
+ element.element,
+ exc.CompileError,
+ "Can't resolve label reference for ORDER BY / GROUP BY.",
)
- return self.process(element._text_clause)
else:
kwargs["render_label_as_label"] = col
return self.process(
@@ -1076,10 +1083,24 @@ class SQLCompiler(Compiled):
if func._has_args:
name += "%(expr)s"
else:
- name = func.name + "%(expr)s"
- return ".".join(list(func.packagenames) + [name]) % {
- "expr": self.function_argspec(func, **kwargs)
- }
+ name = func.name
+ name = (
+ self.preparer.quote(name)
+ if self.preparer._requires_quotes_illegal_chars(name)
+ else name
+ )
+ name = name + "%(expr)s"
+ return ".".join(
+ [
+ (
+ self.preparer.quote(tok)
+ if self.preparer._requires_quotes_illegal_chars(tok)
+ else tok
+ )
+ for tok in func.packagenames
+ ]
+ + [name]
+ ) % {"expr": self.function_argspec(func, **kwargs)}
def visit_next_value_func(self, next_value, **kw):
return self.visit_sequence(next_value.sequence)
@@ -3153,9 +3174,13 @@ class DDLCompiler(Compiled):
def define_constraint_cascades(self, constraint):
text = ""
if constraint.ondelete is not None:
- text += " ON DELETE %s" % constraint.ondelete
+ text += " ON DELETE %s" % self.preparer.validate_sql_phrase(
+ constraint.ondelete, FK_ON_DELETE
+ )
if constraint.onupdate is not None:
- text += " ON UPDATE %s" % constraint.onupdate
+ text += " ON UPDATE %s" % self.preparer.validate_sql_phrase(
+ constraint.onupdate, FK_ON_UPDATE
+ )
return text
def define_constraint_deferrability(self, constraint):
@@ -3166,7 +3191,9 @@ class DDLCompiler(Compiled):
else:
text += " NOT DEFERRABLE"
if constraint.initially is not None:
- text += " INITIALLY %s" % constraint.initially
+ text += " INITIALLY %s" % self.preparer.validate_sql_phrase(
+ constraint.initially, FK_INITIALLY
+ )
return text
def define_constraint_match(self, constraint):
@@ -3416,6 +3443,24 @@ class IdentifierPreparer(object):
return value.replace(self.escape_to_quote, self.escape_quote)
+ def validate_sql_phrase(self, element, reg):
+ """keyword sequence filter.
+
+ a filter for elements that are intended to represent keyword sequences,
+ such as "INITIALLY", "INTIALLY DEFERRED", etc. no special characters
+ should be present.
+
+ .. versionadded:: 1.3
+
+ """
+
+ if element is not None and not reg.match(element):
+ raise exc.CompileError(
+ "Unexpected SQL phrase: %r (matching against %r)"
+ % (element, reg.pattern)
+ )
+ return element
+
def quote_identifier(self, value):
"""Quote an identifier.
@@ -3439,6 +3484,11 @@ class IdentifierPreparer(object):
or (lc_value != value)
)
+ def _requires_quotes_illegal_chars(self, value):
+ """Return True if the given identifier requires quoting, but
+ not taking case convention into account."""
+ return not self.legal_characters.match(util.text_type(value))
+
def quote_schema(self, schema, force=None):
"""Conditionally quote a schema name.
diff --git a/lib/sqlalchemy/sql/elements.py b/lib/sqlalchemy/sql/elements.py
index 9e4f5d95d..a4623128f 100644
--- a/lib/sqlalchemy/sql/elements.py
+++ b/lib/sqlalchemy/sql/elements.py
@@ -37,6 +37,20 @@ def _clone(element, **kw):
return element._clone()
+def _document_text_coercion(paramname, meth_rst, param_rst):
+ return util.add_parameter_text(
+ paramname,
+ (
+ ".. warning:: "
+ "The %s argument to %s can be passed as a Python string argument, "
+ "which will be treated "
+ "as **trusted SQL text** and rendered as given. **DO NOT PASS "
+ "UNTRUSTED INPUT TO THIS PARAMETER**."
+ )
+ % (param_rst, meth_rst),
+ )
+
+
def collate(expression, collation):
"""Return the clause ``expression COLLATE collation``.
@@ -1343,6 +1357,7 @@ class TextClause(Executable, ClauseElement):
"refer to the :meth:`.TextClause.columns` method.",
),
)
+ @_document_text_coercion("text", ":func:`.text`", ":paramref:`.text.text`")
def _create_text(
self, text, bind=None, bindparams=None, typemap=None, autocommit=None
):
@@ -4430,32 +4445,64 @@ def _literal_and_labels_as_label_reference(element):
def _expression_literal_as_text(element):
- return _literal_as_text(element, warn=True)
+ return _literal_as_text(element)
-def _literal_as_text(element, warn=False):
+def _literal_as(element, text_fallback):
if isinstance(element, Visitable):
return element
elif hasattr(element, "__clause_element__"):
return element.__clause_element__()
elif isinstance(element, util.string_types):
- if warn:
- util.warn_limited(
- "Textual SQL expression %(expr)r should be "
- "explicitly declared as text(%(expr)r)",
- {"expr": util.ellipses_string(element)},
- )
-
- return TextClause(util.text_type(element))
+ return text_fallback(element)
elif isinstance(element, (util.NoneType, bool)):
return _const_expr(element)
else:
raise exc.ArgumentError(
- "SQL expression object or string expected, got object of type %r "
+ "SQL expression object expected, got object of type %r "
"instead" % type(element)
)
+def _literal_as_text(element, allow_coercion_to_text=False):
+ if allow_coercion_to_text:
+ return _literal_as(element, TextClause)
+ else:
+ return _literal_as(element, _no_text_coercion)
+
+
+def _literal_as_column(element):
+ return _literal_as(element, ColumnClause)
+
+
+def _no_column_coercion(element):
+ element = str(element)
+ guess_is_literal = not _guess_straight_column.match(element)
+ raise exc.ArgumentError(
+ "Textual column expression %(column)r should be "
+ "explicitly declared with text(%(column)r), "
+ "or use %(literal_column)s(%(column)r) "
+ "for more specificity"
+ % {
+ "column": util.ellipses_string(element),
+ "literal_column": "literal_column"
+ if guess_is_literal
+ else "column",
+ }
+ )
+
+
+def _no_text_coercion(element, exc_cls=exc.ArgumentError, extra=None):
+ raise exc_cls(
+ "%(extra)sTextual SQL expression %(expr)r should be "
+ "explicitly declared as text(%(expr)r)"
+ % {
+ "expr": util.ellipses_string(element),
+ "extra": "%s " % extra if extra else "",
+ }
+ )
+
+
def _no_literals(element):
if hasattr(element, "__clause_element__"):
return element.__clause_element__()
@@ -4529,23 +4576,7 @@ def _interpret_as_column_or_from(element):
elif isinstance(element, (numbers.Number)):
return ColumnClause(str(element), is_literal=True)
else:
- element = str(element)
- # give into temptation, as this fact we are guessing about
- # is not one we've previously ever needed our users tell us;
- # but let them know we are not happy about it
- guess_is_literal = not _guess_straight_column.match(element)
- util.warn_limited(
- "Textual column expression %(column)r should be "
- "explicitly declared with text(%(column)r), "
- "or use %(literal_column)s(%(column)r) "
- "for more specificity",
- {
- "column": util.ellipses_string(element),
- "literal_column": "literal_column"
- if guess_is_literal
- else "column",
- },
- )
+ _no_column_coercion(element)
return ColumnClause(element, is_literal=guess_is_literal)
diff --git a/lib/sqlalchemy/sql/expression.py b/lib/sqlalchemy/sql/expression.py
index 2a27d0b73..82fe93029 100644
--- a/lib/sqlalchemy/sql/expression.py
+++ b/lib/sqlalchemy/sql/expression.py
@@ -101,6 +101,7 @@ from .elements import _expression_literal_as_text # noqa
from .elements import _is_column # noqa
from .elements import _labeled # noqa
from .elements import _literal_as_binds # noqa
+from .elements import _literal_as_column # noqa
from .elements import _literal_as_label_reference # noqa
from .elements import _literal_as_text # noqa
from .elements import _only_column_elements # noqa
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 8997e119f..e981d7aed 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -42,6 +42,7 @@ from .base import ColumnCollection
from .base import DialectKWArgs
from .base import SchemaEventTarget
from .elements import _as_truncated
+from .elements import _document_text_coercion
from .elements import _literal_as_text
from .elements import ClauseElement
from .elements import ColumnClause
@@ -2884,6 +2885,11 @@ class CheckConstraint(ColumnCollectionConstraint):
_allow_multiple_tables = True
+ @_document_text_coercion(
+ "sqltext",
+ ":class:`.CheckConstraint`",
+ ":paramref:`.CheckConstraint.sqltext`",
+ )
def __init__(
self,
sqltext,
@@ -2925,7 +2931,7 @@ class CheckConstraint(ColumnCollectionConstraint):
"""
- self.sqltext = _literal_as_text(sqltext, warn=False)
+ self.sqltext = _literal_as_text(sqltext, allow_coercion_to_text=True)
columns = []
visitors.traverse(self.sqltext, {}, {"column": columns.append})
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index a5dee068c..ac08604f5 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -31,11 +31,13 @@ from .elements import _clause_element_as_expr
from .elements import _clone
from .elements import _cloned_difference
from .elements import _cloned_intersection
+from .elements import _document_text_coercion
from .elements import _expand_cloned
from .elements import _interpret_as_column_or_from
from .elements import _literal_and_labels_as_label_reference
from .elements import _literal_as_label_reference
from .elements import _literal_as_text
+from .elements import _no_text_coercion
from .elements import _select_iterables
from .elements import and_
from .elements import BindParameter
@@ -43,7 +45,6 @@ from .elements import ClauseElement
from .elements import ClauseList
from .elements import Grouping
from .elements import literal_column
-from .elements import TextClause
from .elements import True_
from .elements import UnaryExpression
from .. import exc
@@ -55,14 +56,7 @@ def _interpret_as_from(element):
insp = inspection.inspect(element, raiseerr=False)
if insp is None:
if isinstance(element, util.string_types):
- util.warn_limited(
- "Textual SQL FROM expression %(expr)r should be "
- "explicitly declared as text(%(expr)r), "
- "or use table(%(expr)r) for more specificity",
- {"expr": util.ellipses_string(element)},
- )
-
- return TextClause(util.text_type(element))
+ _no_text_coercion(element)
try:
return insp.selectable
except AttributeError:
@@ -266,6 +260,11 @@ class HasPrefixes(object):
_prefixes = ()
@_generative
+ @_document_text_coercion(
+ "expr",
+ ":meth:`.HasPrefixes.prefix_with`",
+ ":paramref:`.HasPrefixes.prefix_with.*expr`",
+ )
def prefix_with(self, *expr, **kw):
r"""Add one or more expressions following the statement keyword, i.e.
SELECT, INSERT, UPDATE, or DELETE. Generative.
@@ -297,7 +296,10 @@ class HasPrefixes(object):
def _setup_prefixes(self, prefixes, dialect=None):
self._prefixes = self._prefixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in prefixes]
+ [
+ (_literal_as_text(p, allow_coercion_to_text=True), dialect)
+ for p in prefixes
+ ]
)
@@ -305,6 +307,11 @@ class HasSuffixes(object):
_suffixes = ()
@_generative
+ @_document_text_coercion(
+ "expr",
+ ":meth:`.HasSuffixes.suffix_with`",
+ ":paramref:`.HasSuffixes.suffix_with.*expr`",
+ )
def suffix_with(self, *expr, **kw):
r"""Add one or more expressions following the statement as a whole.
@@ -335,7 +342,10 @@ class HasSuffixes(object):
def _setup_suffixes(self, suffixes, dialect=None):
self._suffixes = self._suffixes + tuple(
- [(_literal_as_text(p, warn=False), dialect) for p in suffixes]
+ [
+ (_literal_as_text(p, allow_coercion_to_text=True), dialect)
+ for p in suffixes
+ ]
)
diff --git a/lib/sqlalchemy/util/__init__.py b/lib/sqlalchemy/util/__init__.py
index 1e54ef80b..2f3deb191 100644
--- a/lib/sqlalchemy/util/__init__.py
+++ b/lib/sqlalchemy/util/__init__.py
@@ -93,6 +93,7 @@ from .deprecations import inject_docstring_text # noqa
from .deprecations import pending_deprecation # noqa
from .deprecations import warn_deprecated # noqa
from .deprecations import warn_pending_deprecation # noqa
+from .langhelpers import add_parameter_text # noqa
from .langhelpers import as_interface # noqa
from .langhelpers import asbool # noqa
from .langhelpers import asint # noqa
diff --git a/lib/sqlalchemy/util/deprecations.py b/lib/sqlalchemy/util/deprecations.py
index a43acc72e..9abf4a6be 100644
--- a/lib/sqlalchemy/util/deprecations.py
+++ b/lib/sqlalchemy/util/deprecations.py
@@ -9,11 +9,12 @@
functionality."""
import re
-import textwrap
import warnings
from . import compat
from .langhelpers import decorator
+from .langhelpers import inject_docstring_text
+from .langhelpers import inject_param_text
from .. import exc
@@ -247,64 +248,3 @@ def _decorate_with_warning(func, wtype, message, docstring_header=None):
decorated.__doc__ = doc
decorated._sa_warn = lambda: warnings.warn(message, wtype, stacklevel=3)
return decorated
-
-
-def _dedent_docstring(text):
- split_text = text.split("\n", 1)
- if len(split_text) == 1:
- return text
- else:
- firstline, remaining = split_text
- if not firstline.startswith(" "):
- return firstline + "\n" + textwrap.dedent(remaining)
- else:
- return textwrap.dedent(text)
-
-
-def inject_docstring_text(doctext, injecttext, pos):
- doctext = _dedent_docstring(doctext or "")
- lines = doctext.split("\n")
- injectlines = textwrap.dedent(injecttext).split("\n")
- if injectlines[0]:
- injectlines.insert(0, "")
-
- blanks = [num for num, line in enumerate(lines) if not line.strip()]
- blanks.insert(0, 0)
-
- inject_pos = blanks[min(pos, len(blanks) - 1)]
-
- lines = lines[0:inject_pos] + injectlines + lines[inject_pos:]
- return "\n".join(lines)
-
-
-def inject_param_text(doctext, inject_params):
- doclines = doctext.splitlines()
- lines = []
-
- to_inject = None
- while doclines:
- line = doclines.pop(0)
- if to_inject is None:
- m = re.match(r"(\s+):param (.+?):", line)
- if m:
- param = m.group(2)
- if param in inject_params:
- # default indent to that of :param: plus one
- indent = " " * len(m.group(1)) + " "
-
- # but if the next line has text, use that line's
- # indentntation
- if doclines:
- m2 = re.match(r"(\s+)\S", doclines[0])
- if m2:
- indent = " " * len(m2.group(1))
-
- to_inject = indent + inject_params[param]
- elif not line.rstrip():
- lines.append(line)
- lines.append(to_inject)
- lines.append("\n")
- to_inject = None
- lines.append(line)
-
- return "\n".join(lines)
diff --git a/lib/sqlalchemy/util/langhelpers.py b/lib/sqlalchemy/util/langhelpers.py
index bfe3fd275..198a23a59 100644
--- a/lib/sqlalchemy/util/langhelpers.py
+++ b/lib/sqlalchemy/util/langhelpers.py
@@ -16,6 +16,7 @@ import itertools
import operator
import re
import sys
+import textwrap
import types
import warnings
@@ -1572,3 +1573,82 @@ def quoted_token_parser(value):
idx += 1
return ["".join(token) for token in result]
+
+
+def add_parameter_text(params, text):
+ params = _collections.to_list(params)
+
+ def decorate(fn):
+ doc = fn.__doc__ is not None and fn.__doc__ or ""
+ if doc:
+ doc = inject_param_text(doc, {param: text for param in params})
+ fn.__doc__ = doc
+ return fn
+
+ return decorate
+
+
+def _dedent_docstring(text):
+ split_text = text.split("\n", 1)
+ if len(split_text) == 1:
+ return text
+ else:
+ firstline, remaining = split_text
+ if not firstline.startswith(" "):
+ return firstline + "\n" + textwrap.dedent(remaining)
+ else:
+ return textwrap.dedent(text)
+
+
+def inject_docstring_text(doctext, injecttext, pos):
+ doctext = _dedent_docstring(doctext or "")
+ lines = doctext.split("\n")
+ injectlines = textwrap.dedent(injecttext).split("\n")
+ if injectlines[0]:
+ injectlines.insert(0, "")
+
+ blanks = [num for num, line in enumerate(lines) if not line.strip()]
+ blanks.insert(0, 0)
+
+ inject_pos = blanks[min(pos, len(blanks) - 1)]
+
+ lines = lines[0:inject_pos] + injectlines + lines[inject_pos:]
+ return "\n".join(lines)
+
+
+def inject_param_text(doctext, inject_params):
+ doclines = doctext.splitlines()
+ lines = []
+
+ to_inject = None
+ while doclines:
+ line = doclines.pop(0)
+ if to_inject is None:
+ m = re.match(r"(\s+):param (?:\\\*\*?)?(.+?):", line)
+ if m:
+ param = m.group(2)
+ if param in inject_params:
+ # default indent to that of :param: plus one
+ indent = " " * len(m.group(1)) + " "
+
+ # but if the next line has text, use that line's
+ # indentntation
+ if doclines:
+ m2 = re.match(r"(\s+)\S", doclines[0])
+ if m2:
+ indent = " " * len(m2.group(1))
+
+ to_inject = indent + inject_params[param]
+ elif line.lstrip().startswith(":param "):
+ lines.append("\n")
+ lines.append(to_inject)
+ lines.append("\n")
+ to_inject = None
+ elif not line.rstrip():
+ lines.append(line)
+ lines.append(to_inject)
+ lines.append("\n")
+ to_inject = None
+ lines.append(line)
+
+ return "\n".join(lines)