summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy
diff options
context:
space:
mode:
authorMike Bayer <mike_mp@zzzcomputing.com>2020-03-23 14:52:05 -0400
committerMike Bayer <mike_mp@zzzcomputing.com>2020-03-24 14:25:41 -0400
commitcadfc608d63f4e0df46c0daaa28902423fd88d71 (patch)
tree63b05c466c5c0cbebae5515d7790291305e66cc6 /lib/sqlalchemy
parentfd74bd8eea3f3696c43ca0336ed4e437036c43c5 (diff)
downloadsqlalchemy-cadfc608d63f4e0df46c0daaa28902423fd88d71.tar.gz
Convert schema_translate to a post compile
Revised the :paramref:`.Connection.execution_options.schema_translate_map` feature such that the processing of the SQL statement to receive a specific schema name occurs within the execution phase of the statement, rather than at the compile phase. This is to support the statement being efficiently cached. Previously, the current schema being rendered into the statement for a particular run would be considered as part of the cache key itself, meaning that for a run against hundreds of schemas, there would be hundreds of cache keys, rendering the cache much less performant. The new behavior is that the rendering is done in a similar manner as the "post compile" rendering added in 1.4 as part of :ticket:`4645`, :ticket:`4808`. Fixes: #5004 Change-Id: Ia5c89eb27cc8dc2c5b8e76d6c07c46290a7901b6
Diffstat (limited to 'lib/sqlalchemy')
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py3
-rw-r--r--lib/sqlalchemy/engine/base.py68
-rw-r--r--lib/sqlalchemy/engine/default.py27
-rw-r--r--lib/sqlalchemy/engine/mock.py4
-rw-r--r--lib/sqlalchemy/engine/reflection.py3
-rw-r--r--lib/sqlalchemy/sql/compiler.py100
-rw-r--r--lib/sqlalchemy/sql/schema.py57
-rw-r--r--lib/sqlalchemy/sql/selectable.py3
-rw-r--r--lib/sqlalchemy/sql/sqltypes.py5
-rw-r--r--lib/sqlalchemy/testing/assertions.py8
-rw-r--r--lib/sqlalchemy/testing/assertsql.py16
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py1
12 files changed, 147 insertions, 148 deletions
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index a63ce0033..31425d4c0 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -1631,6 +1631,9 @@ class SQLiteDialect(default.DefaultDialect):
)
return bool(info)
+ def _get_default_schema_name(self, connection):
+ return "main"
+
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
if schema is not None:
diff --git a/lib/sqlalchemy/engine/base.py b/lib/sqlalchemy/engine/base.py
index aa21fb13b..4ed3b9af7 100644
--- a/lib/sqlalchemy/engine/base.py
+++ b/lib/sqlalchemy/engine/base.py
@@ -17,7 +17,6 @@ from .. import inspection
from .. import log
from .. import util
from ..sql import compiler
-from ..sql import schema
from ..sql import util as sql_util
@@ -51,21 +50,7 @@ class Connection(Connectable):
"""
- schema_for_object = schema._schema_getter(None)
- """Return the ".schema" attribute for an object.
-
- Used for :class:`.Table`, :class:`.Sequence` and similar objects,
- and takes into account
- the :paramref:`.Connection.execution_options.schema_translate_map`
- parameter.
-
- .. versionadded:: 1.1
-
- .. seealso::
-
- :ref:`schema_translating`
-
- """
+ _schema_translate_map = None
def __init__(
self,
@@ -92,7 +77,7 @@ class Connection(Connectable):
self.should_close_with_result = False
self.dispatch = _dispatch
self._has_events = _branch_from._has_events
- self.schema_for_object = _branch_from.schema_for_object
+ self._schema_translate_map = _branch_from._schema_translate_map
else:
self.__connection = (
connection
@@ -122,6 +107,24 @@ class Connection(Connectable):
if self._has_events or self.engine._has_events:
self.dispatch.engine_connect(self, self.__branch)
+ def schema_for_object(self, obj):
+ """return the schema name for the given schema item taking into
+ account current schema translate map.
+
+ """
+
+ name = obj.schema
+ schema_translate_map = self._schema_translate_map
+
+ if (
+ schema_translate_map
+ and name in schema_translate_map
+ and obj._use_schema_map
+ ):
+ return schema_translate_map[name]
+ else:
+ return name
+
def _branch(self):
"""Return a new Connection which references this Connection's
engine and connection; but does not have close_with_result enabled,
@@ -1066,10 +1069,7 @@ class Connection(Connectable):
dialect = self.dialect
compiled = ddl.compile(
- dialect=dialect,
- schema_translate_map=self.schema_for_object
- if not self.schema_for_object.is_default
- else None,
+ dialect=dialect, schema_translate_map=self._schema_translate_map
)
ret = self._execute_context(
dialect,
@@ -1103,7 +1103,7 @@ class Connection(Connectable):
dialect,
elem,
tuple(sorted(keys)),
- self.schema_for_object.hash_key,
+ bool(self._schema_translate_map),
len(distilled_params) > 1,
)
compiled_sql = self._execution_options["compiled_cache"].get(key)
@@ -1112,9 +1112,7 @@ class Connection(Connectable):
dialect=dialect,
column_keys=keys,
inline=len(distilled_params) > 1,
- schema_translate_map=self.schema_for_object
- if not self.schema_for_object.is_default
- else None,
+ schema_translate_map=self._schema_translate_map,
linting=self.dialect.compiler_linting
| compiler.WARN_LINTING,
)
@@ -1124,9 +1122,7 @@ class Connection(Connectable):
dialect=dialect,
column_keys=keys,
inline=len(distilled_params) > 1,
- schema_translate_map=self.schema_for_object
- if not self.schema_for_object.is_default
- else None,
+ schema_translate_map=self._schema_translate_map,
linting=self.dialect.compiler_linting | compiler.WARN_LINTING,
)
@@ -1974,21 +1970,7 @@ class Engine(Connectable, log.Identified):
_has_events = False
_connection_cls = Connection
- schema_for_object = schema._schema_getter(None)
- """Return the ".schema" attribute for an object.
-
- Used for :class:`.Table`, :class:`.Sequence` and similar objects,
- and takes into account
- the :paramref:`.Connection.execution_options.schema_translate_map`
- parameter.
-
- .. versionadded:: 1.1
-
- .. seealso::
-
- :ref:`schema_translating`
-
- """
+ _schema_translate_map = None
def __init__(
self,
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index b151b6e48..d0940decf 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -28,7 +28,6 @@ from .. import types as sqltypes
from .. import util
from ..sql import compiler
from ..sql import expression
-from ..sql import schema
from ..sql.elements import quoted_name
AUTOCOMMIT_REGEXP = re.compile(
@@ -129,6 +128,8 @@ class DefaultDialect(interfaces.Dialect):
server_version_info = None
+ default_schema_name = None
+
construct_arguments = None
"""Optional set of argument specifiers for various SQLAlchemy
constructs, typically schema items.
@@ -495,20 +496,18 @@ class DefaultDialect(interfaces.Dialect):
self._set_connection_isolation(connection, isolation_level)
if "schema_translate_map" in opts:
- getter = schema._schema_getter(opts["schema_translate_map"])
- engine.schema_for_object = getter
+ engine._schema_translate_map = map_ = opts["schema_translate_map"]
@event.listens_for(engine, "engine_connect")
def set_schema_translate_map(connection, branch):
- connection.schema_for_object = getter
+ connection._schema_translate_map = map_
def set_connection_execution_options(self, connection, opts):
if "isolation_level" in opts:
self._set_connection_isolation(connection, opts["isolation_level"])
if "schema_translate_map" in opts:
- getter = schema._schema_getter(opts["schema_translate_map"])
- connection.schema_for_object = getter
+ connection._schema_translate_map = opts["schema_translate_map"]
def _set_connection_isolation(self, connection, level):
if connection.in_transaction():
@@ -701,11 +700,17 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
self.execution_options = dict(self.execution_options)
self.execution_options.update(connection._execution_options)
+ self.unicode_statement = util.text_type(compiled)
+ if compiled.schema_translate_map:
+ rst = compiled.preparer._render_schema_translates
+ self.unicode_statement = rst(
+ self.unicode_statement, connection._schema_translate_map
+ )
+
if not dialect.supports_unicode_statements:
- self.unicode_statement = util.text_type(compiled)
self.statement = dialect._encoder(self.unicode_statement)[0]
else:
- self.statement = self.unicode_statement = util.text_type(compiled)
+ self.statement = self.unicode_statement
self.cursor = self.create_cursor()
self.compiled_parameters = []
@@ -807,6 +812,12 @@ class DefaultExecutionContext(interfaces.ExecutionContext):
elif compiled.positional:
positiontup = self.compiled.positiontup
+ if compiled.schema_translate_map:
+ rst = compiled.preparer._render_schema_translates
+ self.unicode_statement = rst(
+ self.unicode_statement, connection._schema_translate_map
+ )
+
# final self.unicode_statement is now assigned, encode if needed
# by dialect
if not dialect.supports_unicode_statements:
diff --git a/lib/sqlalchemy/engine/mock.py b/lib/sqlalchemy/engine/mock.py
index 570ee2d04..bda9e91b5 100644
--- a/lib/sqlalchemy/engine/mock.py
+++ b/lib/sqlalchemy/engine/mock.py
@@ -11,7 +11,6 @@ from . import base
from . import url as _url
from .. import util
from ..sql import ddl
-from ..sql import schema
class MockConnection(base.Connectable):
@@ -23,7 +22,8 @@ class MockConnection(base.Connectable):
dialect = property(attrgetter("_dialect"))
name = property(lambda s: s._dialect.name)
- schema_for_object = schema._schema_getter(None)
+ def schema_for_object(self, obj):
+ return obj.schema
def connect(self, **kwargs):
return self
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index 203369ed8..8ef0d572f 100644
--- a/lib/sqlalchemy/engine/reflection.py
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -701,7 +701,8 @@ class Inspector(object):
dialect = self.bind.dialect
- schema = self.bind.schema_for_object(table)
+ with self._operation_context() as conn:
+ schema = conn.schema_for_object(table)
table_name = table.name
diff --git a/lib/sqlalchemy/sql/compiler.py b/lib/sqlalchemy/sql/compiler.py
index 1f183b5c1..ae9c3c73a 100644
--- a/lib/sqlalchemy/sql/compiler.py
+++ b/lib/sqlalchemy/sql/compiler.py
@@ -26,6 +26,7 @@ To generate user-defined SQL strings, see
import collections
import contextlib
import itertools
+import operator
import re
from . import base
@@ -39,6 +40,7 @@ from . import schema
from . import selectable
from . import sqltypes
from .base import NO_ARG
+from .elements import quoted_name
from .. import exc
from .. import util
@@ -369,6 +371,8 @@ class Compiled(object):
_cached_metadata = None
+ schema_translate_map = None
+
execution_options = util.immutabledict()
"""
Execution options propagated from the statement. In some cases,
@@ -381,6 +385,7 @@ class Compiled(object):
statement,
bind=None,
schema_translate_map=None,
+ render_schema_translate=False,
compile_kwargs=util.immutabledict(),
):
"""Construct a new :class:`.Compiled` object.
@@ -411,6 +416,7 @@ class Compiled(object):
self.bind = bind
self.preparer = self.dialect.identifier_preparer
if schema_translate_map:
+ self.schema_translate_map = schema_translate_map
self.preparer = self.preparer._with_schema_translate(
schema_translate_map
)
@@ -422,6 +428,11 @@ class Compiled(object):
self.execution_options = statement._execution_options
self.string = self.process(self.statement, **compile_kwargs)
+ if render_schema_translate:
+ self.string = self.preparer._render_schema_translates(
+ self.string, schema_translate_map
+ )
+
@util.deprecated(
"0.7",
"The :meth:`.Compiled.compile` method is deprecated and will be "
@@ -3365,18 +3376,18 @@ class DDLCompiler(Compiled):
return self.sql_compiler.post_process_text(ddl.statement % context)
- def visit_create_schema(self, create):
+ def visit_create_schema(self, create, **kw):
schema = self.preparer.format_schema(create.element)
return "CREATE SCHEMA " + schema
- def visit_drop_schema(self, drop):
+ def visit_drop_schema(self, drop, **kw):
schema = self.preparer.format_schema(drop.element)
text = "DROP SCHEMA " + schema
if drop.cascade:
text += " CASCADE"
return text
- def visit_create_table(self, create):
+ def visit_create_table(self, create, **kw):
table = create.element
preparer = self.preparer
@@ -3426,7 +3437,7 @@ class DDLCompiler(Compiled):
text += "\n)%s\n\n" % self.post_create_table(table)
return text
- def visit_create_column(self, create, first_pk=False):
+ def visit_create_column(self, create, first_pk=False, **kw):
column = create.element
if column.system:
@@ -3442,7 +3453,7 @@ class DDLCompiler(Compiled):
return text
def create_table_constraints(
- self, table, _include_foreign_key_constraints=None
+ self, table, _include_foreign_key_constraints=None, **kw
):
# On some DB order is significant: visit PK first, then the
@@ -3482,10 +3493,10 @@ class DDLCompiler(Compiled):
if p is not None
)
- def visit_drop_table(self, drop):
+ def visit_drop_table(self, drop, **kw):
return "\nDROP TABLE " + self.preparer.format_table(drop.element)
- def visit_drop_view(self, drop):
+ def visit_drop_view(self, drop, **kw):
return "\nDROP VIEW " + self.preparer.format_table(drop.element)
def _verify_index_table(self, index):
@@ -3495,7 +3506,7 @@ class DDLCompiler(Compiled):
)
def visit_create_index(
- self, create, include_schema=False, include_table_schema=True
+ self, create, include_schema=False, include_table_schema=True, **kw
):
index = create.element
self._verify_index_table(index)
@@ -3521,7 +3532,7 @@ class DDLCompiler(Compiled):
)
return text
- def visit_drop_index(self, drop):
+ def visit_drop_index(self, drop, **kw):
index = drop.element
if index.name is None:
@@ -3548,13 +3559,13 @@ class DDLCompiler(Compiled):
index_name = schema_name + "." + index_name
return index_name
- def visit_add_constraint(self, create):
+ def visit_add_constraint(self, create, **kw):
return "ALTER TABLE %s ADD %s" % (
self.preparer.format_table(create.element.table),
self.process(create.element),
)
- def visit_set_table_comment(self, create):
+ def visit_set_table_comment(self, create, **kw):
return "COMMENT ON TABLE %s IS %s" % (
self.preparer.format_table(create.element),
self.sql_compiler.render_literal_value(
@@ -3562,12 +3573,12 @@ class DDLCompiler(Compiled):
),
)
- def visit_drop_table_comment(self, drop):
+ def visit_drop_table_comment(self, drop, **kw):
return "COMMENT ON TABLE %s IS NULL" % self.preparer.format_table(
drop.element
)
- def visit_set_column_comment(self, create):
+ def visit_set_column_comment(self, create, **kw):
return "COMMENT ON COLUMN %s IS %s" % (
self.preparer.format_column(
create.element, use_table=True, use_schema=True
@@ -3577,12 +3588,12 @@ class DDLCompiler(Compiled):
),
)
- def visit_drop_column_comment(self, drop):
+ def visit_drop_column_comment(self, drop, **kw):
return "COMMENT ON COLUMN %s IS NULL" % self.preparer.format_column(
drop.element, use_table=True
)
- def visit_create_sequence(self, create):
+ def visit_create_sequence(self, create, **kw):
text = "CREATE SEQUENCE %s" % self.preparer.format_sequence(
create.element
)
@@ -3606,10 +3617,10 @@ class DDLCompiler(Compiled):
text += " CYCLE"
return text
- def visit_drop_sequence(self, drop):
+ def visit_drop_sequence(self, drop, **kw):
return "DROP SEQUENCE %s" % self.preparer.format_sequence(drop.element)
- def visit_drop_constraint(self, drop):
+ def visit_drop_constraint(self, drop, **kw):
constraint = drop.element
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
@@ -3671,7 +3682,7 @@ class DDLCompiler(Compiled):
else:
return self.visit_check_constraint(constraint)
- def visit_check_constraint(self, constraint):
+ def visit_check_constraint(self, constraint, **kw):
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
@@ -3683,7 +3694,7 @@ class DDLCompiler(Compiled):
text += self.define_constraint_deferrability(constraint)
return text
- def visit_column_check_constraint(self, constraint):
+ def visit_column_check_constraint(self, constraint, **kw):
text = ""
if constraint.name is not None:
formatted_name = self.preparer.format_constraint(constraint)
@@ -3695,7 +3706,7 @@ class DDLCompiler(Compiled):
text += self.define_constraint_deferrability(constraint)
return text
- def visit_primary_key_constraint(self, constraint):
+ def visit_primary_key_constraint(self, constraint, **kw):
if len(constraint) == 0:
return ""
text = ""
@@ -3715,7 +3726,7 @@ class DDLCompiler(Compiled):
text += self.define_constraint_deferrability(constraint)
return text
- def visit_foreign_key_constraint(self, constraint):
+ def visit_foreign_key_constraint(self, constraint, **kw):
preparer = self.preparer
text = ""
if constraint.name is not None:
@@ -3744,7 +3755,7 @@ class DDLCompiler(Compiled):
return preparer.format_table(table)
- def visit_unique_constraint(self, constraint):
+ def visit_unique_constraint(self, constraint, **kw):
if len(constraint) == 0:
return ""
text = ""
@@ -3789,7 +3800,7 @@ class DDLCompiler(Compiled):
text += " MATCH %s" % constraint.match
return text
- def visit_computed_column(self, generated):
+ def visit_computed_column(self, generated, **kw):
text = "GENERATED ALWAYS AS (%s)" % self.sql_compiler.process(
generated.sqltext, include_table=False, literal_binds=True
)
@@ -3975,7 +3986,16 @@ class IdentifierPreparer(object):
illegal_initial_characters = ILLEGAL_INITIAL_CHARACTERS
- schema_for_object = schema._schema_getter(None)
+ schema_for_object = operator.attrgetter("schema")
+ """Return the .schema attribute for an object.
+
+ For the default IdentifierPreparer, the schema for an object is always
+ the value of the ".schema" attribute. if the preparer is replaced
+ with one that has a non-empty schema_translate_map, the value of the
+ ".schema" attribute is rendered a symbol that will be converted to a
+ real schema name from the mapping post-compile.
+
+ """
def __init__(
self,
@@ -4016,9 +4036,39 @@ class IdentifierPreparer(object):
def _with_schema_translate(self, schema_translate_map):
prep = self.__class__.__new__(self.__class__)
prep.__dict__.update(self.__dict__)
- prep.schema_for_object = schema._schema_getter(schema_translate_map)
+
+ def symbol_getter(obj):
+ name = obj.schema
+ if name in schema_translate_map and obj._use_schema_map:
+ return quoted_name(
+ "[SCHEMA_%s]" % (name or "_none"), quote=False
+ )
+ else:
+ return obj.schema
+
+ prep.schema_for_object = symbol_getter
return prep
+ def _render_schema_translates(self, statement, schema_translate_map):
+ d = schema_translate_map
+ if None in d:
+ d["_none"] = d[None]
+
+ def replace(m):
+ name = m.group(2)
+ effective_schema = d[name]
+ if not effective_schema:
+ effective_schema = self.dialect.default_schema_name
+ if not effective_schema:
+ # TODO: no coverage here
+ raise exc.CompileError(
+ "Dialect has no default schema name; can't "
+ "use None as dynamic schema target."
+ )
+ return self.quote(effective_schema)
+
+ return re.sub(r"(\[SCHEMA_([\w\d_]+)\])", replace, statement)
+
def _escape_identifier(self, value):
"""Escape an identifier.
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index 69f60ba24..02c14d751 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -31,7 +31,6 @@ as components in SQL expressions.
from __future__ import absolute_import
import collections
-import operator
import sqlalchemy
from . import coercions
@@ -143,8 +142,7 @@ class SchemaItem(SchemaEventTarget, visitors.Visitable):
schema_item.dispatch._update(self.dispatch)
return schema_item
- def _translate_schema(self, effective_schema, map_):
- return map_.get(effective_schema, effective_schema)
+ _use_schema_map = True
class Table(DialectKWArgs, SchemaItem, TableClause):
@@ -4270,59 +4268,6 @@ class ThreadLocalMetaData(MetaData):
e.dispose()
-class _SchemaTranslateMap(object):
- """Provide translation of schema names based on a mapping.
-
- Also provides helpers for producing cache keys and optimized
- access when no mapping is present.
-
- Used by the :paramref:`.Connection.execution_options.schema_translate_map`
- feature.
-
- .. versionadded:: 1.1
-
-
- """
-
- __slots__ = "map_", "__call__", "hash_key", "is_default"
-
- _default_schema_getter = operator.attrgetter("schema")
-
- def __init__(self, map_):
- self.map_ = map_
- if map_ is not None:
-
- def schema_for_object(obj):
- effective_schema = self._default_schema_getter(obj)
- effective_schema = obj._translate_schema(
- effective_schema, map_
- )
- return effective_schema
-
- self.__call__ = schema_for_object
- self.hash_key = ";".join(
- "%s=%s" % (k, map_[k]) for k in sorted(map_, key=str)
- )
- self.is_default = False
- else:
- self.hash_key = 0
- self.__call__ = self._default_schema_getter
- self.is_default = True
-
- @classmethod
- def _schema_getter(cls, map_):
- if map_ is None:
- return _default_schema_map
- elif isinstance(map_, _SchemaTranslateMap):
- return map_
- else:
- return _SchemaTranslateMap(map_)
-
-
-_default_schema_map = _SchemaTranslateMap(None)
-_schema_getter = _SchemaTranslateMap._schema_getter
-
-
class Computed(FetchedValue, SchemaItem):
"""Defines a generated column, i.e. "GENERATED ALWAYS AS" syntax.
diff --git a/lib/sqlalchemy/sql/selectable.py b/lib/sqlalchemy/sql/selectable.py
index 45b9e7f9d..ab13b21c4 100644
--- a/lib/sqlalchemy/sql/selectable.py
+++ b/lib/sqlalchemy/sql/selectable.py
@@ -346,8 +346,7 @@ class FromClause(HasMemoized, roles.AnonymizedFromClauseRole, Selectable):
_is_from_clause = True
_is_join = False
- def _translate_schema(self, effective_schema, map_):
- return effective_schema
+ _use_schema_map = False
_memoized_property = util.group_expirable_memoized_property(["_columns"])
diff --git a/lib/sqlalchemy/sql/sqltypes.py b/lib/sqlalchemy/sql/sqltypes.py
index 3d69d1177..e106684bc 100644
--- a/lib/sqlalchemy/sql/sqltypes.py
+++ b/lib/sqlalchemy/sql/sqltypes.py
@@ -1000,6 +1000,8 @@ class SchemaType(SchemaEventTarget):
"""
+ _use_schema_map = True
+
def __init__(
self,
name=None,
@@ -1030,9 +1032,6 @@ class SchemaType(SchemaEventTarget):
util.portable_instancemethod(self._on_metadata_drop),
)
- def _translate_schema(self, effective_schema, map_):
- return map_.get(effective_schema, effective_schema)
-
def _set_parent(self, column):
column._on_table_attach(util.portable_instancemethod(self._set_table))
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index e0bf4326e..7dada1394 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -352,6 +352,8 @@ class AssertsCompiledSQL(object):
literal_binds=False,
render_postcompile=False,
schema_translate_map=None,
+ render_schema_translate=False,
+ default_schema_name=None,
inline_flag=None,
):
if use_default_dialect:
@@ -371,6 +373,9 @@ class AssertsCompiledSQL(object):
elif isinstance(dialect, util.string_types):
dialect = url.URL(dialect).get_dialect()()
+ if default_schema_name:
+ dialect.default_schema_name = default_schema_name
+
kw = {}
compile_kwargs = {}
@@ -386,6 +391,9 @@ class AssertsCompiledSQL(object):
if render_postcompile:
compile_kwargs["render_postcompile"] = True
+ if render_schema_translate:
+ kw["render_schema_translate"] = True
+
from sqlalchemy import orm
if isinstance(clause, orm.Query):
diff --git a/lib/sqlalchemy/testing/assertsql.py b/lib/sqlalchemy/testing/assertsql.py
index e38c7ddd8..f0da69400 100644
--- a/lib/sqlalchemy/testing/assertsql.py
+++ b/lib/sqlalchemy/testing/assertsql.py
@@ -91,21 +91,23 @@ class CompiledSQL(SQLMatchRule):
context = execute_observed.context
compare_dialect = self._compile_dialect(execute_observed)
+
+ if "schema_translate_map" in context.execution_options:
+ map_ = context.execution_options["schema_translate_map"]
+ else:
+ map_ = None
+
if isinstance(context.compiled.statement, _DDLCompiles):
+
compiled = context.compiled.statement.compile(
- dialect=compare_dialect,
- schema_translate_map=context.execution_options.get(
- "schema_translate_map"
- ),
+ dialect=compare_dialect, schema_translate_map=map_
)
else:
compiled = context.compiled.statement.compile(
dialect=compare_dialect,
column_keys=context.compiled.column_keys,
inline=context.compiled.inline,
- schema_translate_map=context.execution_options.get(
- "schema_translate_map"
- ),
+ schema_translate_map=map_,
)
_received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
parameters = execute_observed.parameters
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index 473c98116..68a43feb7 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -360,7 +360,6 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.schema_reflection
def test_dialect_initialize(self):
engine = engines.testing_engine()
- assert not hasattr(engine.dialect, "default_schema_name")
inspect(engine)
assert hasattr(engine.dialect, "default_schema_name")