summaryrefslogtreecommitdiff
path: root/lib
diff options
context:
space:
mode:
Diffstat (limited to 'lib')
-rw-r--r--lib/sqlalchemy/dialects/mssql/base.py172
-rw-r--r--lib/sqlalchemy/dialects/mysql/base.py47
-rw-r--r--lib/sqlalchemy/dialects/oracle/base.py2240
-rw-r--r--lib/sqlalchemy/dialects/oracle/cx_oracle.py4
-rw-r--r--lib/sqlalchemy/dialects/oracle/dictionary.py495
-rw-r--r--lib/sqlalchemy/dialects/oracle/provision.py54
-rw-r--r--lib/sqlalchemy/dialects/oracle/types.py233
-rw-r--r--lib/sqlalchemy/dialects/postgresql/__init__.py31
-rw-r--r--lib/sqlalchemy/dialects/postgresql/_psycopg_common.py18
-rw-r--r--lib/sqlalchemy/dialects/postgresql/asyncpg.py5
-rw-r--r--lib/sqlalchemy/dialects/postgresql/base.py2624
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pg8000.py7
-rw-r--r--lib/sqlalchemy/dialects/postgresql/pg_catalog.py292
-rw-r--r--lib/sqlalchemy/dialects/postgresql/types.py485
-rw-r--r--lib/sqlalchemy/dialects/sqlite/base.py161
-rw-r--r--lib/sqlalchemy/engine/__init__.py2
-rw-r--r--lib/sqlalchemy/engine/default.py131
-rw-r--r--lib/sqlalchemy/engine/interfaces.py417
-rw-r--r--lib/sqlalchemy/engine/reflection.py1475
-rw-r--r--lib/sqlalchemy/sql/base.py2
-rw-r--r--lib/sqlalchemy/sql/cache_key.py38
-rw-r--r--lib/sqlalchemy/sql/schema.py42
-rw-r--r--lib/sqlalchemy/testing/assertions.py30
-rw-r--r--lib/sqlalchemy/testing/plugin/pytestplugin.py16
-rw-r--r--lib/sqlalchemy/testing/provision.py70
-rw-r--r--lib/sqlalchemy/testing/requirements.py46
-rw-r--r--lib/sqlalchemy/testing/schema.py15
-rw-r--r--lib/sqlalchemy/testing/suite/test_reflection.py1544
-rw-r--r--lib/sqlalchemy/testing/suite/test_sequence.py33
-rw-r--r--lib/sqlalchemy/testing/util.py39
-rw-r--r--lib/sqlalchemy/util/topological.py10
-rw-r--r--lib/sqlalchemy/util/typing.py2
32 files changed, 7768 insertions, 3012 deletions
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 12f495d6e..2a4362ccb 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -831,6 +831,7 @@ from ... import util
from ...engine import cursor as _cursor
from ...engine import default
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
@@ -3010,55 +3011,16 @@ class MSDialect(default.DefaultDialect):
return self.schema_name
@_db_plus_owner
- def has_table(self, connection, tablename, dbname, owner, schema):
+ def has_table(self, connection, tablename, dbname, owner, schema, **kw):
self._ensure_has_table_connection(connection)
- if tablename.startswith("#"): # temporary table
- # mssql does not support temporary views
- # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed
- tables = ischema.mssql_temp_table_columns
- s = sql.select(tables.c.table_name).where(
- tables.c.table_name.like(
- self._temp_table_name_like_pattern(tablename)
- )
- )
-
- # #7168: fetch all (not just first match) in case some other #temp
- # table with the same name happens to appear first
- table_names = connection.execute(s).scalars().fetchall()
- # #6910: verify it's not a temp table from another session
- for table_name in table_names:
- if bool(
- connection.scalar(
- text("SELECT object_id(:table_name)"),
- {"table_name": "tempdb.dbo.[{}]".format(table_name)},
- )
- ):
- return True
- else:
- return False
- else:
- tables = ischema.tables
-
- s = sql.select(tables.c.table_name).where(
- sql.and_(
- sql.or_(
- tables.c.table_type == "BASE TABLE",
- tables.c.table_type == "VIEW",
- ),
- tables.c.table_name == tablename,
- )
- )
-
- if owner:
- s = s.where(tables.c.table_schema == owner)
-
- c = connection.execute(s)
-
- return c.first() is not None
+ return self._internal_has_table(connection, tablename, owner, **kw)
+ @reflection.cache
@_db_plus_owner
- def has_sequence(self, connection, sequencename, dbname, owner, schema):
+ def has_sequence(
+ self, connection, sequencename, dbname, owner, schema, **kw
+ ):
sequences = ischema.sequences
s = sql.select(sequences.c.sequence_name).where(
@@ -3128,6 +3090,60 @@ class MSDialect(default.DefaultDialect):
return view_names
@reflection.cache
+ def _internal_has_table(self, connection, tablename, owner, **kw):
+ if tablename.startswith("#"): # temporary table
+ # mssql does not support temporary views
+ # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed
+ tables = ischema.mssql_temp_table_columns
+
+ s = sql.select(tables.c.table_name).where(
+ tables.c.table_name.like(
+ self._temp_table_name_like_pattern(tablename)
+ )
+ )
+
+ # #7168: fetch all (not just first match) in case some other #temp
+ # table with the same name happens to appear first
+ table_names = connection.scalars(s).all()
+ # #6910: verify it's not a temp table from another session
+ for table_name in table_names:
+ if bool(
+ connection.scalar(
+ text("SELECT object_id(:table_name)"),
+ {"table_name": "tempdb.dbo.[{}]".format(table_name)},
+ )
+ ):
+ return True
+ else:
+ return False
+ else:
+ tables = ischema.tables
+
+ s = sql.select(tables.c.table_name).where(
+ sql.and_(
+ sql.or_(
+ tables.c.table_type == "BASE TABLE",
+ tables.c.table_type == "VIEW",
+ ),
+ tables.c.table_name == tablename,
+ )
+ )
+
+ if owner:
+ s = s.where(tables.c.table_schema == owner)
+
+ c = connection.execute(s)
+
+ return c.first() is not None
+
+ def _default_or_error(self, connection, tablename, owner, method, **kw):
+ # TODO: try to avoid having to run a separate query here
+ if self._internal_has_table(connection, tablename, owner, **kw):
+ return method()
+ else:
+ raise exc.NoSuchTableError(f"{owner}.{tablename}")
+
+ @reflection.cache
@_db_plus_owner
def get_indexes(self, connection, tablename, dbname, owner, schema, **kw):
filter_definition = (
@@ -3138,14 +3154,14 @@ class MSDialect(default.DefaultDialect):
rp = connection.execution_options(future_result=True).execute(
sql.text(
"select ind.index_id, ind.is_unique, ind.name, "
- "%s "
+ f"{filter_definition} "
"from sys.indexes as ind join sys.tables as tab on "
"ind.object_id=tab.object_id "
"join sys.schemas as sch on sch.schema_id=tab.schema_id "
"where tab.name = :tabname "
"and sch.name=:schname "
- "and ind.is_primary_key=0 and ind.type != 0"
- % filter_definition
+ "and ind.is_primary_key=0 and ind.type != 0 "
+ "order by ind.name "
)
.bindparams(
sql.bindparam("tabname", tablename, ischema.CoerceUnicode()),
@@ -3203,31 +3219,34 @@ class MSDialect(default.DefaultDialect):
"mssql_include"
] = index_info["include_columns"]
- return list(indexes.values())
+ if indexes:
+ return list(indexes.values())
+ else:
+ return self._default_or_error(
+ connection, tablename, owner, ReflectionDefaults.indexes, **kw
+ )
@reflection.cache
@_db_plus_owner
def get_view_definition(
self, connection, viewname, dbname, owner, schema, **kw
):
- rp = connection.execute(
+ view_def = connection.execute(
sql.text(
- "select definition from sys.sql_modules as mod, "
- "sys.views as views, "
- "sys.schemas as sch"
- " where "
- "mod.object_id=views.object_id and "
- "views.schema_id=sch.schema_id and "
- "views.name=:viewname and sch.name=:schname"
+ "select mod.definition "
+ "from sys.sql_modules as mod "
+ "join sys.views as views on mod.object_id = views.object_id "
+ "join sys.schemas as sch on views.schema_id = sch.schema_id "
+ "where views.name=:viewname and sch.name=:schname"
).bindparams(
sql.bindparam("viewname", viewname, ischema.CoerceUnicode()),
sql.bindparam("schname", owner, ischema.CoerceUnicode()),
)
- )
-
- if rp:
- view_def = rp.scalar()
+ ).scalar()
+ if view_def:
return view_def
+ else:
+ raise exc.NoSuchTableError(f"{owner}.{viewname}")
def _temp_table_name_like_pattern(self, tablename):
# LIKE uses '%' to match zero or more characters and '_' to match any
@@ -3417,7 +3436,12 @@ class MSDialect(default.DefaultDialect):
cols.append(cdict)
- return cols
+ if cols:
+ return cols
+ else:
+ return self._default_or_error(
+ connection, tablename, owner, ReflectionDefaults.columns, **kw
+ )
@reflection.cache
@_db_plus_owner
@@ -3450,7 +3474,16 @@ class MSDialect(default.DefaultDialect):
pkeys.append(row["COLUMN_NAME"])
if constraint_name is None:
constraint_name = row[C.c.constraint_name.name]
- return {"constrained_columns": pkeys, "name": constraint_name}
+ if pkeys:
+ return {"constrained_columns": pkeys, "name": constraint_name}
+ else:
+ return self._default_or_error(
+ connection,
+ tablename,
+ owner,
+ ReflectionDefaults.pk_constraint,
+ **kw,
+ )
@reflection.cache
@_db_plus_owner
@@ -3591,7 +3624,7 @@ index_info AS (
fkeys = util.defaultdict(fkey_rec)
- for r in connection.execute(s).fetchall():
+ for r in connection.execute(s).all():
(
_, # constraint schema
rfknm,
@@ -3632,4 +3665,13 @@ index_info AS (
local_cols.append(scol)
remote_cols.append(rcol)
- return list(fkeys.values())
+ if fkeys:
+ return list(fkeys.values())
+ else:
+ return self._default_or_error(
+ connection,
+ tablename,
+ owner,
+ ReflectionDefaults.foreign_keys,
+ **kw,
+ )
diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py
index 7c9a68236..502371be9 100644
--- a/lib/sqlalchemy/dialects/mysql/base.py
+++ b/lib/sqlalchemy/dialects/mysql/base.py
@@ -1056,6 +1056,7 @@ from ... import sql
from ... import util
from ...engine import default
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
@@ -2648,7 +2649,8 @@ class MySQLDialect(default.DefaultDialect):
def _get_default_schema_name(self, connection):
return connection.exec_driver_sql("SELECT DATABASE()").scalar()
- def has_table(self, connection, table_name, schema=None):
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
self._ensure_has_table_connection(connection)
if schema is None:
@@ -2670,7 +2672,8 @@ class MySQLDialect(default.DefaultDialect):
)
return bool(rs.scalar())
- def has_sequence(self, connection, sequence_name, schema=None):
+ @reflection.cache
+ def has_sequence(self, connection, sequence_name, schema=None, **kw):
if not self.supports_sequences:
self._sequences_not_supported()
if not schema:
@@ -2847,14 +2850,20 @@ class MySQLDialect(default.DefaultDialect):
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- return parsed_state.table_options
+ if parsed_state.table_options:
+ return parsed_state.table_options
+ else:
+ return ReflectionDefaults.table_options()
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- return parsed_state.columns
+ if parsed_state.columns:
+ return parsed_state.columns
+ else:
+ return ReflectionDefaults.columns()
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
@@ -2866,7 +2875,7 @@ class MySQLDialect(default.DefaultDialect):
# There can be only one.
cols = [s[0] for s in key["columns"]]
return {"constrained_columns": cols, "name": None}
- return {"constrained_columns": [], "name": None}
+ return ReflectionDefaults.pk_constraint()
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
@@ -2909,7 +2918,7 @@ class MySQLDialect(default.DefaultDialect):
if self._needs_correct_for_88718_96365:
self._correct_for_mysql_bugs_88718_96365(fkeys, connection)
- return fkeys
+ return fkeys if fkeys else ReflectionDefaults.foreign_keys()
def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection):
# Foreign key is always in lower case (MySQL 8.0)
@@ -3000,21 +3009,22 @@ class MySQLDialect(default.DefaultDialect):
connection, table_name, schema, **kw
)
- return [
+ cks = [
{"name": spec["name"], "sqltext": spec["sqltext"]}
for spec in parsed_state.ck_constraints
]
+ return cks if cks else ReflectionDefaults.check_constraints()
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
parsed_state = self._parsed_state_or_create(
connection, table_name, schema, **kw
)
- return {
- "text": parsed_state.table_options.get(
- "%s_comment" % self.name, None
- )
- }
+ comment = parsed_state.table_options.get(f"{self.name}_comment", None)
+ if comment is not None:
+ return {"text": comment}
+ else:
+ return ReflectionDefaults.table_comment()
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
@@ -3058,7 +3068,8 @@ class MySQLDialect(default.DefaultDialect):
if flavor:
index_d["type"] = flavor
indexes.append(index_d)
- return indexes
+ indexes.sort(key=lambda d: d["name"] or "~") # sort None as last
+ return indexes if indexes else ReflectionDefaults.indexes()
@reflection.cache
def get_unique_constraints(
@@ -3068,7 +3079,7 @@ class MySQLDialect(default.DefaultDialect):
connection, table_name, schema, **kw
)
- return [
+ ucs = [
{
"name": key["name"],
"column_names": [col[0] for col in key["columns"]],
@@ -3077,6 +3088,11 @@ class MySQLDialect(default.DefaultDialect):
for key in parsed_state.keys
if key["type"] == "UNIQUE"
]
+ ucs.sort(key=lambda d: d["name"] or "~") # sort None as last
+ if ucs:
+ return ucs
+ else:
+ return ReflectionDefaults.unique_constraints()
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
@@ -3088,6 +3104,9 @@ class MySQLDialect(default.DefaultDialect):
sql = self._show_create_table(
connection, None, charset, full_name=full_name
)
+ if sql.upper().startswith("CREATE TABLE"):
+ # it's a table, not a view
+ raise exc.NoSuchTableError(full_name)
return sql
def _parsed_state_or_create(
diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py
index faac0deb7..fee098889 100644
--- a/lib/sqlalchemy/dialects/oracle/base.py
+++ b/lib/sqlalchemy/dialects/oracle/base.py
@@ -518,21 +518,52 @@ columns for non-unique indexes, all but the last column for unique indexes).
""" # noqa
-from itertools import groupby
+from __future__ import annotations
+
+from collections import defaultdict
+from functools import lru_cache
+from functools import wraps
import re
+from . import dictionary
+from .types import _OracleBoolean
+from .types import _OracleDate
+from .types import BFILE
+from .types import BINARY_DOUBLE
+from .types import BINARY_FLOAT
+from .types import DATE
+from .types import FLOAT
+from .types import INTERVAL
+from .types import LONG
+from .types import NCLOB
+from .types import NUMBER
+from .types import NVARCHAR2 # noqa
+from .types import OracleRaw # noqa
+from .types import RAW
+from .types import ROWID # noqa
+from .types import VARCHAR2 # noqa
from ... import Computed
from ... import exc
from ... import schema as sa_schema
from ... import sql
from ... import util
from ...engine import default
+from ...engine import ObjectKind
+from ...engine import ObjectScope
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
+from ...sql import and_
+from ...sql import bindparam
from ...sql import compiler
from ...sql import expression
+from ...sql import func
+from ...sql import null
+from ...sql import or_
+from ...sql import select
from ...sql import sqltypes
from ...sql import util as sql_util
from ...sql import visitors
+from ...sql.visitors import InternalTraversal
from ...types import BLOB
from ...types import CHAR
from ...types import CLOB
@@ -561,229 +592,6 @@ NO_ARG_FNS = set(
)
-class RAW(sqltypes._Binary):
- __visit_name__ = "RAW"
-
-
-OracleRaw = RAW
-
-
-class NCLOB(sqltypes.Text):
- __visit_name__ = "NCLOB"
-
-
-class VARCHAR2(VARCHAR):
- __visit_name__ = "VARCHAR2"
-
-
-NVARCHAR2 = NVARCHAR
-
-
-class NUMBER(sqltypes.Numeric, sqltypes.Integer):
- __visit_name__ = "NUMBER"
-
- def __init__(self, precision=None, scale=None, asdecimal=None):
- if asdecimal is None:
- asdecimal = bool(scale and scale > 0)
-
- super(NUMBER, self).__init__(
- precision=precision, scale=scale, asdecimal=asdecimal
- )
-
- def adapt(self, impltype):
- ret = super(NUMBER, self).adapt(impltype)
- # leave a hint for the DBAPI handler
- ret._is_oracle_number = True
- return ret
-
- @property
- def _type_affinity(self):
- if bool(self.scale and self.scale > 0):
- return sqltypes.Numeric
- else:
- return sqltypes.Integer
-
-
-class FLOAT(sqltypes.FLOAT):
- """Oracle FLOAT.
-
- This is the same as :class:`_sqltypes.FLOAT` except that
- an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
- parameter is accepted, and
- the :paramref:`_sqltypes.Float.precision` parameter is not accepted.
-
- Oracle FLOAT types indicate precision in terms of "binary precision", which
- defaults to 126. For a REAL type, the value is 63. This parameter does not
- cleanly map to a specific number of decimal places but is roughly
- equivalent to the desired number of decimal places divided by 0.3103.
-
- .. versionadded:: 2.0
-
- """
-
- __visit_name__ = "FLOAT"
-
- def __init__(
- self,
- binary_precision=None,
- asdecimal=False,
- decimal_return_scale=None,
- ):
- r"""
- Construct a FLOAT
-
- :param binary_precision: Oracle binary precision value to be rendered
- in DDL. This may be approximated to the number of decimal characters
- using the formula "decimal precision = 0.30103 * binary precision".
- The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.
-
- :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`
-
- :param decimal_return_scale: See
- :paramref:`_sqltypes.Float.decimal_return_scale`
-
- """
- super().__init__(
- asdecimal=asdecimal, decimal_return_scale=decimal_return_scale
- )
- self.binary_precision = binary_precision
-
-
-class BINARY_DOUBLE(sqltypes.Float):
- __visit_name__ = "BINARY_DOUBLE"
-
-
-class BINARY_FLOAT(sqltypes.Float):
- __visit_name__ = "BINARY_FLOAT"
-
-
-class BFILE(sqltypes.LargeBinary):
- __visit_name__ = "BFILE"
-
-
-class LONG(sqltypes.Text):
- __visit_name__ = "LONG"
-
-
-class _OracleDateLiteralRender:
- def _literal_processor_datetime(self, dialect):
- def process(value):
- if value is not None:
- if getattr(value, "microsecond", None):
- value = (
- f"""TO_TIMESTAMP"""
- f"""('{value.isoformat().replace("T", " ")}', """
- """'YYYY-MM-DD HH24:MI:SS.FF')"""
- )
- else:
- value = (
- f"""TO_DATE"""
- f"""('{value.isoformat().replace("T", " ")}', """
- """'YYYY-MM-DD HH24:MI:SS')"""
- )
- return value
-
- return process
-
- def _literal_processor_date(self, dialect):
- def process(value):
- if value is not None:
- if getattr(value, "microsecond", None):
- value = (
- f"""TO_TIMESTAMP"""
- f"""('{value.isoformat().split("T")[0]}', """
- """'YYYY-MM-DD')"""
- )
- else:
- value = (
- f"""TO_DATE"""
- f"""('{value.isoformat().split("T")[0]}', """
- """'YYYY-MM-DD')"""
- )
- return value
-
- return process
-
-
-class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
- """Provide the oracle DATE type.
-
- This type has no special Python behavior, except that it subclasses
- :class:`_types.DateTime`; this is to suit the fact that the Oracle
- ``DATE`` type supports a time value.
-
- .. versionadded:: 0.9.4
-
- """
-
- __visit_name__ = "DATE"
-
- def literal_processor(self, dialect):
- return self._literal_processor_datetime(dialect)
-
- def _compare_type_affinity(self, other):
- return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
-
-
-class _OracleDate(_OracleDateLiteralRender, sqltypes.Date):
- def literal_processor(self, dialect):
- return self._literal_processor_date(dialect)
-
-
-class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
- __visit_name__ = "INTERVAL"
-
- def __init__(self, day_precision=None, second_precision=None):
- """Construct an INTERVAL.
-
- Note that only DAY TO SECOND intervals are currently supported.
- This is due to a lack of support for YEAR TO MONTH intervals
- within available DBAPIs.
-
- :param day_precision: the day precision value. this is the number of
- digits to store for the day field. Defaults to "2"
- :param second_precision: the second precision value. this is the
- number of digits to store for the fractional seconds field.
- Defaults to "6".
-
- """
- self.day_precision = day_precision
- self.second_precision = second_precision
-
- @classmethod
- def _adapt_from_generic_interval(cls, interval):
- return INTERVAL(
- day_precision=interval.day_precision,
- second_precision=interval.second_precision,
- )
-
- @property
- def _type_affinity(self):
- return sqltypes.Interval
-
- def as_generic(self, allow_nulltype=False):
- return sqltypes.Interval(
- native=True,
- second_precision=self.second_precision,
- day_precision=self.day_precision,
- )
-
-
-class ROWID(sqltypes.TypeEngine):
- """Oracle ROWID type.
-
- When used in a cast() or similar, generates ROWID.
-
- """
-
- __visit_name__ = "ROWID"
-
-
-class _OracleBoolean(sqltypes.Boolean):
- def get_dbapi_type(self, dbapi):
- return dbapi.NUMBER
-
-
colspecs = {
sqltypes.Boolean: _OracleBoolean,
sqltypes.Interval: INTERVAL,
@@ -1541,6 +1349,13 @@ class OracleExecutionContext(default.DefaultExecutionContext):
type_,
)
+ def pre_exec(self):
+ if self.statement and "_oracle_dblink" in self.execution_options:
+ self.statement = self.statement.replace(
+ dictionary.DB_LINK_PLACEHOLDER,
+ self.execution_options["_oracle_dblink"],
+ )
+
class OracleDialect(default.DefaultDialect):
name = "oracle"
@@ -1675,6 +1490,10 @@ class OracleDialect(default.DefaultDialect):
# it may work also on versions before the 18
return self.server_version_info and self.server_version_info >= (18,)
+ @property
+ def _supports_except_all(self):
+ return self.server_version_info and self.server_version_info >= (21,)
+
def do_release_savepoint(self, connection, name):
# Oracle does not support RELEASE SAVEPOINT
pass
@@ -1700,45 +1519,99 @@ class OracleDialect(default.DefaultDialect):
except:
return "READ COMMITTED"
- def has_table(self, connection, table_name, schema=None):
+ def _execute_reflection(
+ self, connection, query, dblink, returns_long, params=None
+ ):
+ if dblink and not dblink.startswith("@"):
+ dblink = f"@{dblink}"
+ execution_options = {
+ # handle db links
+ "_oracle_dblink": dblink or "",
+ # override any schema translate map
+ "schema_translate_map": None,
+ }
+
+ if dblink and returns_long:
+ # Oracle seems to error with
+ # "ORA-00997: illegal use of LONG datatype" when returning
+ # LONG columns via a dblink in a query with bind params
+ # This type seems to be very hard to cast into something else
+ # so it seems easier to just use bind param in this case
+ def visit_bindparam(bindparam):
+ bindparam.literal_execute = True
+
+ query = visitors.cloned_traverse(
+ query, {}, {"bindparam": visit_bindparam}
+ )
+ return connection.execute(
+ query, params, execution_options=execution_options
+ )
+
+ @util.memoized_property
+ def _has_table_query(self):
+ # materialized views are returned by all_tables
+ tables = (
+ select(
+ dictionary.all_tables.c.table_name,
+ dictionary.all_tables.c.owner,
+ )
+ .union_all(
+ select(
+ dictionary.all_views.c.view_name.label("table_name"),
+ dictionary.all_views.c.owner,
+ )
+ )
+ .subquery("tables_and_views")
+ )
+
+ query = select(tables.c.table_name).where(
+ tables.c.table_name == bindparam("table_name"),
+ tables.c.owner == bindparam("owner"),
+ )
+ return query
+
+ @reflection.cache
+ def has_table(
+ self, connection, table_name, schema=None, dblink=None, **kw
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
self._ensure_has_table_connection(connection)
if not schema:
schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- """SELECT table_name FROM all_tables
- WHERE table_name = CAST(:name AS VARCHAR2(128))
- AND owner = CAST(:schema_name AS VARCHAR2(128))
- UNION ALL
- SELECT view_name FROM all_views
- WHERE view_name = CAST(:name AS VARCHAR2(128))
- AND owner = CAST(:schema_name AS VARCHAR2(128))
- """
- ),
- dict(
- name=self.denormalize_name(table_name),
- schema_name=self.denormalize_name(schema),
- ),
+ params = {
+ "table_name": self.denormalize_name(table_name),
+ "owner": self.denormalize_name(schema),
+ }
+ cursor = self._execute_reflection(
+ connection,
+ self._has_table_query,
+ dblink,
+ returns_long=False,
+ params=params,
)
- return cursor.first() is not None
+ return bool(cursor.scalar())
- def has_sequence(self, connection, sequence_name, schema=None):
+ @reflection.cache
+ def has_sequence(
+ self, connection, sequence_name, schema=None, dblink=None, **kw
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
if not schema:
schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT sequence_name FROM all_sequences "
- "WHERE sequence_name = :name AND "
- "sequence_owner = :schema_name"
- ),
- dict(
- name=self.denormalize_name(sequence_name),
- schema_name=self.denormalize_name(schema),
- ),
+
+ query = select(dictionary.all_sequences.c.sequence_name).where(
+ dictionary.all_sequences.c.sequence_name
+ == self.denormalize_name(sequence_name),
+ dictionary.all_sequences.c.sequence_owner
+ == self.denormalize_name(schema),
)
- return cursor.first() is not None
+
+ cursor = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ )
+ return bool(cursor.scalar())
def _get_default_schema_name(self, connection):
return self.normalize_name(
@@ -1747,329 +1620,633 @@ class OracleDialect(default.DefaultDialect):
).scalar()
)
- def _resolve_synonym(
- self,
- connection,
- desired_owner=None,
- desired_synonym=None,
- desired_table=None,
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("dblink", InternalTraversal.dp_string),
+ )
+ def _get_synonyms(self, connection, schema, filter_names, dblink, **kw):
+ owner = self.denormalize_name(schema or self.default_schema_name)
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = select(
+ dictionary.all_synonyms.c.synonym_name,
+ dictionary.all_synonyms.c.table_name,
+ dictionary.all_synonyms.c.table_owner,
+ dictionary.all_synonyms.c.db_link,
+ ).where(dictionary.all_synonyms.c.owner == owner)
+ if has_filter_names:
+ query = query.where(
+ dictionary.all_synonyms.c.synonym_name.in_(
+ params["filter_names"]
+ )
+ )
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).mappings()
+ return result.all()
+
+ @lru_cache()
+ def _all_objects_query(
+ self, owner, scope, kind, has_filter_names, has_mat_views
):
- """search for a local synonym matching the given desired owner/name.
-
- if desired_owner is None, attempts to locate a distinct owner.
-
- returns the actual name, owner, dblink name, and synonym name if
- found.
- """
-
- q = (
- "SELECT owner, table_owner, table_name, db_link, "
- "synonym_name FROM all_synonyms WHERE "
+ query = (
+ select(dictionary.all_objects.c.object_name)
+ .select_from(dictionary.all_objects)
+ .where(dictionary.all_objects.c.owner == owner)
)
- clauses = []
- params = {}
- if desired_synonym:
- clauses.append(
- "synonym_name = CAST(:synonym_name AS VARCHAR2(128))"
+
+ # NOTE: materialized views are listed in all_objects twice;
+ # once as MATERIALIZE VIEW and once as TABLE
+ if kind is ObjectKind.ANY:
+ # materilaized view are listed also as tables so there is no
+ # need to add them to the in_.
+ query = query.where(
+ dictionary.all_objects.c.object_type.in_(("TABLE", "VIEW"))
)
- params["synonym_name"] = desired_synonym
- if desired_owner:
- clauses.append("owner = CAST(:desired_owner AS VARCHAR2(128))")
- params["desired_owner"] = desired_owner
- if desired_table:
- clauses.append("table_name = CAST(:tname AS VARCHAR2(128))")
- params["tname"] = desired_table
-
- q += " AND ".join(clauses)
-
- result = connection.execution_options(future_result=True).execute(
- sql.text(q), params
- )
- if desired_owner:
- row = result.mappings().first()
- if row:
- return (
- row["table_name"],
- row["table_owner"],
- row["db_link"],
- row["synonym_name"],
- )
- else:
- return None, None, None, None
else:
- rows = result.mappings().all()
- if len(rows) > 1:
- raise AssertionError(
- "There are multiple tables visible to the schema, you "
- "must specify owner"
- )
- elif len(rows) == 1:
- row = rows[0]
- return (
- row["table_name"],
- row["table_owner"],
- row["db_link"],
- row["synonym_name"],
- )
- else:
- return None, None, None, None
+ object_type = []
+ if ObjectKind.VIEW in kind:
+ object_type.append("VIEW")
+ if (
+ ObjectKind.MATERIALIZED_VIEW in kind
+ and ObjectKind.TABLE not in kind
+ ):
+ # materilaized view are listed also as tables so there is no
+ # need to add them to the in_ if also selecting tables.
+ object_type.append("MATERIALIZED VIEW")
+ if ObjectKind.TABLE in kind:
+ object_type.append("TABLE")
+ if has_mat_views and ObjectKind.MATERIALIZED_VIEW not in kind:
+ # materialized view are listed also as tables,
+ # so they need to be filtered out
+ # EXCEPT ALL / MINUS profiles as faster than using
+ # NOT EXISTS or NOT IN with a subquery, but it's in
+ # general faster to get the mat view names and exclude
+ # them only when needed
+ query = query.where(
+ dictionary.all_objects.c.object_name.not_in(
+ bindparam("mat_views")
+ )
+ )
+ query = query.where(
+ dictionary.all_objects.c.object_type.in_(object_type)
+ )
- @reflection.cache
- def _prepare_reflection_args(
- self,
- connection,
- table_name,
- schema=None,
- resolve_synonyms=False,
- dblink="",
- **kw,
- ):
+ # handles scope
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(dictionary.all_objects.c.temporary == "N")
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(dictionary.all_objects.c.temporary == "Y")
- if resolve_synonyms:
- actual_name, owner, dblink, synonym = self._resolve_synonym(
- connection,
- desired_owner=self.denormalize_name(schema),
- desired_synonym=self.denormalize_name(table_name),
+ if has_filter_names:
+ query = query.where(
+ dictionary.all_objects.c.object_name.in_(
+ bindparam("filter_names")
+ )
)
- else:
- actual_name, owner, dblink, synonym = None, None, None, None
- if not actual_name:
- actual_name = self.denormalize_name(table_name)
-
- if dblink:
- # using user_db_links here since all_db_links appears
- # to have more restricted permissions.
- # https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
- # will need to hear from more users if we are doing
- # the right thing here. See [ticket:2619]
- owner = connection.scalar(
- sql.text(
- "SELECT username FROM user_db_links " "WHERE db_link=:link"
- ),
- dict(link=dblink),
+ return query
+
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("scope", InternalTraversal.dp_plain_obj),
+ ("kind", InternalTraversal.dp_plain_obj),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("dblink", InternalTraversal.dp_string),
+ )
+ def _get_all_objects(
+ self, connection, schema, scope, kind, filter_names, dblink, **kw
+ ):
+ owner = self.denormalize_name(schema or self.default_schema_name)
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ has_mat_views = False
+ if (
+ ObjectKind.TABLE in kind
+ and ObjectKind.MATERIALIZED_VIEW not in kind
+ ):
+ # see note in _all_objects_query
+ mat_views = self.get_materialized_view_names(
+ connection, schema, dblink, _normalize=False, **kw
)
- dblink = "@" + dblink
- elif not owner:
- owner = self.denormalize_name(schema or self.default_schema_name)
+ if mat_views:
+ params["mat_views"] = mat_views
+ has_mat_views = True
+
+ query = self._all_objects_query(
+ owner, scope, kind, has_filter_names, has_mat_views
+ )
- return (actual_name, owner, dblink or "", synonym)
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False, params=params
+ ).scalars()
- @reflection.cache
- def get_schema_names(self, connection, **kw):
- s = "SELECT username FROM all_users ORDER BY username"
- cursor = connection.exec_driver_sql(s)
- return [self.normalize_name(row[0]) for row in cursor]
+ return result.all()
+
+ def _handle_synonyms_decorator(fn):
+ @wraps(fn)
+ def wrapper(self, *args, **kwargs):
+ return self._handle_synonyms(fn, *args, **kwargs)
+
+ return wrapper
+
+ def _handle_synonyms(self, fn, connection, *args, **kwargs):
+ if not kwargs.get("oracle_resolve_synonyms", False):
+ return fn(self, connection, *args, **kwargs)
+
+ original_kw = kwargs.copy()
+ schema = kwargs.pop("schema", None)
+ result = self._get_synonyms(
+ connection,
+ schema=schema,
+ filter_names=kwargs.pop("filter_names", None),
+ dblink=kwargs.pop("dblink", None),
+ info_cache=kwargs.get("info_cache", None),
+ )
+
+ dblinks_owners = defaultdict(dict)
+ for row in result:
+ key = row["db_link"], row["table_owner"]
+ tn = self.normalize_name(row["table_name"])
+ dblinks_owners[key][tn] = row["synonym_name"]
+
+ if not dblinks_owners:
+ # No synonym, do the plain thing
+ return fn(self, connection, *args, **original_kw)
+
+ data = {}
+ for (dblink, table_owner), mapping in dblinks_owners.items():
+ call_kw = {
+ **original_kw,
+ "schema": table_owner,
+ "dblink": self.normalize_name(dblink),
+ "filter_names": mapping.keys(),
+ }
+ call_result = fn(self, connection, *args, **call_kw)
+ for (_, tn), value in call_result:
+ synonym_name = self.normalize_name(mapping[tn])
+ data[(schema, synonym_name)] = value
+ return data.items()
@reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
- schema = self.denormalize_name(schema or self.default_schema_name)
+ def get_schema_names(self, connection, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
+ query = select(dictionary.all_users.c.username).order_by(
+ dictionary.all_users.c.username
+ )
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
+ @reflection.cache
+ def get_table_names(self, connection, schema=None, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
# note that table_names() isn't loading DBLINKed or synonym'ed tables
if schema is None:
schema = self.default_schema_name
- sql_str = "SELECT table_name FROM all_tables WHERE "
+ den_schema = self.denormalize_name(schema)
+ if kw.get("oracle_resolve_synonyms", False):
+ tables = (
+ select(
+ dictionary.all_tables.c.table_name,
+ dictionary.all_tables.c.owner,
+ dictionary.all_tables.c.iot_name,
+ dictionary.all_tables.c.duration,
+ dictionary.all_tables.c.tablespace_name,
+ )
+ .union_all(
+ select(
+ dictionary.all_synonyms.c.synonym_name.label(
+ "table_name"
+ ),
+ dictionary.all_synonyms.c.owner,
+ dictionary.all_tables.c.iot_name,
+ dictionary.all_tables.c.duration,
+ dictionary.all_tables.c.tablespace_name,
+ )
+ .select_from(dictionary.all_tables)
+ .join(
+ dictionary.all_synonyms,
+ and_(
+ dictionary.all_tables.c.table_name
+ == dictionary.all_synonyms.c.table_name,
+ dictionary.all_tables.c.owner
+ == func.coalesce(
+ dictionary.all_synonyms.c.table_owner,
+ dictionary.all_synonyms.c.owner,
+ ),
+ ),
+ )
+ )
+ .subquery("available_tables")
+ )
+ else:
+ tables = dictionary.all_tables
+
+ query = select(tables.c.table_name)
if self.exclude_tablespaces:
- sql_str += (
- "nvl(tablespace_name, 'no tablespace') "
- "NOT IN (%s) AND "
- % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ query = query.where(
+ func.coalesce(
+ tables.c.tablespace_name, "no tablespace"
+ ).not_in(self.exclude_tablespaces)
)
- sql_str += (
- "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL"
+ query = query.where(
+ tables.c.owner == den_schema,
+ tables.c.iot_name.is_(null()),
+ tables.c.duration.is_(null()),
)
- cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
- return [self.normalize_name(row[0]) for row in cursor]
+ # remove materialized views
+ mat_query = select(
+ dictionary.all_mviews.c.mview_name.label("table_name")
+ ).where(dictionary.all_mviews.c.owner == den_schema)
+
+ query = (
+ query.except_all(mat_query)
+ if self._supports_except_all
+ else query.except_(mat_query)
+ )
+
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
@reflection.cache
- def get_temp_table_names(self, connection, **kw):
+ def get_temp_table_names(self, connection, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
schema = self.denormalize_name(self.default_schema_name)
- sql_str = "SELECT table_name FROM all_tables WHERE "
+ query = select(dictionary.all_tables.c.table_name)
if self.exclude_tablespaces:
- sql_str += (
- "nvl(tablespace_name, 'no tablespace') "
- "NOT IN (%s) AND "
- % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces]))
+ query = query.where(
+ func.coalesce(
+ dictionary.all_tables.c.tablespace_name, "no tablespace"
+ ).not_in(self.exclude_tablespaces)
)
- sql_str += (
- "OWNER = :owner "
- "AND IOT_NAME IS NULL "
- "AND DURATION IS NOT NULL"
+ query = query.where(
+ dictionary.all_tables.c.owner == schema,
+ dictionary.all_tables.c.iot_name.is_(null()),
+ dictionary.all_tables.c.duration.is_not(null()),
)
- cursor = connection.execute(sql.text(sql_str), dict(owner=schema))
- return [self.normalize_name(row[0]) for row in cursor]
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
@reflection.cache
- def get_view_names(self, connection, schema=None, **kw):
- schema = self.denormalize_name(schema or self.default_schema_name)
- s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner")
- cursor = connection.execute(
- s, dict(owner=self.denormalize_name(schema))
+ def get_materialized_view_names(
+ self, connection, schema=None, dblink=None, _normalize=True, **kw
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
+ if not schema:
+ schema = self.default_schema_name
+
+ query = select(dictionary.all_mviews.c.mview_name).where(
+ dictionary.all_mviews.c.owner == self.denormalize_name(schema)
)
- return [self.normalize_name(row[0]) for row in cursor]
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ if _normalize:
+ return [self.normalize_name(row) for row in result]
+ else:
+ return result.all()
@reflection.cache
- def get_sequence_names(self, connection, schema=None, **kw):
+ def get_view_names(self, connection, schema=None, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
if not schema:
schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT sequence_name FROM all_sequences "
- "WHERE sequence_owner = :schema_name"
- ),
- dict(schema_name=self.denormalize_name(schema)),
+
+ query = select(dictionary.all_views.c.view_name).where(
+ dictionary.all_views.c.owner == self.denormalize_name(schema)
)
- return [self.normalize_name(row[0]) for row in cursor]
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
@reflection.cache
- def get_table_options(self, connection, table_name, schema=None, **kw):
- options = {}
+ def get_sequence_names(self, connection, schema=None, dblink=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link."""
+ if not schema:
+ schema = self.default_schema_name
+ query = select(dictionary.all_sequences.c.sequence_name).where(
+ dictionary.all_sequences.c.sequence_owner
+ == self.denormalize_name(schema)
+ )
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(row) for row in result]
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ def _value_or_raise(self, data, table, schema):
+ table = self.normalize_name(str(table))
+ try:
+ return dict(data)[(schema, table)]
+ except KeyError:
+ raise exc.NoSuchTableError(
+ f"{schema}.{table}" if schema else table
+ ) from None
+
+ def _prepare_filter_names(self, filter_names):
+ if filter_names:
+ fn = [self.denormalize_name(name) for name in filter_names]
+ return True, {"filter_names": fn}
+ else:
+ return False, {}
+
+ @reflection.cache
+ def get_table_options(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_table_options(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- params = {"table_name": table_name}
+ @lru_cache()
+ def _table_options_query(
+ self, owner, scope, kind, has_filter_names, has_mat_views
+ ):
+ query = select(
+ dictionary.all_tables.c.table_name,
+ dictionary.all_tables.c.compression,
+ dictionary.all_tables.c.compress_for,
+ ).where(dictionary.all_tables.c.owner == owner)
+ if has_filter_names:
+ query = query.where(
+ dictionary.all_tables.c.table_name.in_(
+ bindparam("filter_names")
+ )
+ )
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(dictionary.all_tables.c.duration.is_(null()))
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(
+ dictionary.all_tables.c.duration.is_not(null())
+ )
- columns = ["table_name"]
- if self._supports_table_compression:
- columns.append("compression")
- if self._supports_table_compress_for:
- columns.append("compress_for")
+ if (
+ has_mat_views
+ and ObjectKind.TABLE in kind
+ and ObjectKind.MATERIALIZED_VIEW not in kind
+ ):
+ # cant use EXCEPT ALL / MINUS here because we don't have an
+ # excludable row vs. the query above
+ # outerjoin + where null works better on oracle 21 but 11 does
+ # not like it at all. this is the next best thing
+
+ query = query.where(
+ dictionary.all_tables.c.table_name.not_in(
+ bindparam("mat_views")
+ )
+ )
+ elif (
+ ObjectKind.TABLE not in kind
+ and ObjectKind.MATERIALIZED_VIEW in kind
+ ):
+ query = query.where(
+ dictionary.all_tables.c.table_name.in_(bindparam("mat_views"))
+ )
+ return query
- text = (
- "SELECT %(columns)s "
- "FROM ALL_TABLES%(dblink)s "
- "WHERE table_name = CAST(:table_name AS VARCHAR(128))"
- )
+ @_handle_synonyms_decorator
+ def get_multi_table_options(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ owner = self.denormalize_name(schema or self.default_schema_name)
- if schema is not None:
- params["owner"] = schema
- text += " AND owner = CAST(:owner AS VARCHAR(128)) "
- text = text % {"dblink": dblink, "columns": ", ".join(columns)}
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ has_mat_views = False
- result = connection.execute(sql.text(text), params)
+ if (
+ ObjectKind.TABLE in kind
+ and ObjectKind.MATERIALIZED_VIEW not in kind
+ ):
+ # see note in _table_options_query
+ mat_views = self.get_materialized_view_names(
+ connection, schema, dblink, _normalize=False, **kw
+ )
+ if mat_views:
+ params["mat_views"] = mat_views
+ has_mat_views = True
+ elif (
+ ObjectKind.TABLE not in kind
+ and ObjectKind.MATERIALIZED_VIEW in kind
+ ):
+ mat_views = self.get_materialized_view_names(
+ connection, schema, dblink, _normalize=False, **kw
+ )
+ params["mat_views"] = mat_views
- enabled = dict(DISABLED=False, ENABLED=True)
+ options = {}
+ default = ReflectionDefaults.table_options
- row = result.first()
- if row:
- if "compression" in row._fields and enabled.get(
- row.compression, False
- ):
- if "compress_for" in row._fields:
- options["oracle_compress"] = row.compress_for
+ if ObjectKind.TABLE in kind or ObjectKind.MATERIALIZED_VIEW in kind:
+ query = self._table_options_query(
+ owner, scope, kind, has_filter_names, has_mat_views
+ )
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False, params=params
+ )
+
+ for table, compression, compress_for in result:
+ if compression == "ENABLED":
+ data = {"oracle_compress": compress_for}
else:
- options["oracle_compress"] = True
+ data = default()
+ options[(schema, self.normalize_name(table))] = data
+ if ObjectKind.VIEW in kind and ObjectScope.DEFAULT in scope:
+ # add the views (no temporary views)
+ for view in self.get_view_names(connection, schema, dblink, **kw):
+ if not filter_names or view in filter_names:
+ options[(schema, view)] = default()
- return options
+ return options.items()
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
- kw arguments can be:
+ data = self.get_multi_columns(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ def _run_batches(
+ self, connection, query, dblink, returns_long, mappings, all_objects
+ ):
+ each_batch = 500
+ batches = list(all_objects)
+ while batches:
+ batch = batches[0:each_batch]
+ batches[0:each_batch] = []
+
+ result = self._execute_reflection(
+ connection,
+ query,
+ dblink,
+ returns_long=returns_long,
+ params={"all_objects": batch},
+ )
+ if mappings:
+ yield from result.mappings()
+ else:
+ yield from result
+
+ @lru_cache()
+ def _column_query(self, owner):
+ all_cols = dictionary.all_tab_cols
+ all_comments = dictionary.all_col_comments
+ all_ids = dictionary.all_tab_identity_cols
- oracle_resolve_synonyms
+ if self.server_version_info >= (12,):
+ add_cols = (
+ all_cols.c.default_on_null,
+ sql.case(
+ (all_ids.c.table_name.is_(None), sql.null()),
+ else_=all_ids.c.generation_type
+ + ","
+ + all_ids.c.identity_options,
+ ).label("identity_options"),
+ )
+ join_identity_cols = True
+ else:
+ add_cols = (
+ sql.null().label("default_on_null"),
+ sql.null().label("identity_options"),
+ )
+ join_identity_cols = False
+
+ # NOTE: on oracle cannot create tables/views without columns and
+ # a table cannot have all column hidden:
+ # ORA-54039: table must have at least one column that is not invisible
+ # all_tab_cols returns data for tables/views/mat-views.
+ # all_tab_cols does not return recycled tables
+
+ query = (
+ select(
+ all_cols.c.table_name,
+ all_cols.c.column_name,
+ all_cols.c.data_type,
+ all_cols.c.char_length,
+ all_cols.c.data_precision,
+ all_cols.c.data_scale,
+ all_cols.c.nullable,
+ all_cols.c.data_default,
+ all_comments.c.comments,
+ all_cols.c.virtual_column,
+ *add_cols,
+ ).select_from(all_cols)
+ # NOTE: all_col_comments has a row for each column even if no
+ # comment is present, so a join could be performed, but there
+ # seems to be no difference compared to an outer join
+ .outerjoin(
+ all_comments,
+ and_(
+ all_cols.c.table_name == all_comments.c.table_name,
+ all_cols.c.column_name == all_comments.c.column_name,
+ all_cols.c.owner == all_comments.c.owner,
+ ),
+ )
+ )
+ if join_identity_cols:
+ query = query.outerjoin(
+ all_ids,
+ and_(
+ all_cols.c.table_name == all_ids.c.table_name,
+ all_cols.c.column_name == all_ids.c.column_name,
+ all_cols.c.owner == all_ids.c.owner,
+ ),
+ )
- dblink
+ query = query.where(
+ all_cols.c.table_name.in_(bindparam("all_objects")),
+ all_cols.c.hidden_column == "NO",
+ all_cols.c.owner == owner,
+ ).order_by(all_cols.c.table_name, all_cols.c.column_id)
+ return query
+ @_handle_synonyms_decorator
+ def get_multi_columns(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ query = self._column_query(owner)
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
+ if (
+ filter_names
+ and kind is ObjectKind.ANY
+ and scope is ObjectScope.ANY
+ ):
+ all_objects = [self.denormalize_name(n) for n in filter_names]
+ else:
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
+ )
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ columns = defaultdict(list)
+
+ # all_tab_cols.data_default is LONG
+ result = self._run_batches(
connection,
- table_name,
- schema,
- resolve_synonyms,
+ query,
dblink,
- info_cache=info_cache,
+ returns_long=True,
+ mappings=True,
+ all_objects=all_objects,
)
- columns = []
- if self._supports_char_length:
- char_length_col = "char_length"
- else:
- char_length_col = "data_length"
- if self.server_version_info >= (12,):
- identity_cols = """\
- col.default_on_null,
- (
- SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS
- FROM ALL_TAB_IDENTITY_COLS%(dblink)s id
- WHERE col.table_name = id.table_name
- AND col.column_name = id.column_name
- AND col.owner = id.owner
- ) AS identity_options""" % {
- "dblink": dblink
- }
- else:
- identity_cols = "NULL as default_on_null, NULL as identity_options"
-
- params = {"table_name": table_name}
-
- text = """
- SELECT
- col.column_name,
- col.data_type,
- col.%(char_length_col)s,
- col.data_precision,
- col.data_scale,
- col.nullable,
- col.data_default,
- com.comments,
- col.virtual_column,
- %(identity_cols)s
- FROM all_tab_cols%(dblink)s col
- LEFT JOIN all_col_comments%(dblink)s com
- ON col.table_name = com.table_name
- AND col.column_name = com.column_name
- AND col.owner = com.owner
- WHERE col.table_name = CAST(:table_name AS VARCHAR2(128))
- AND col.hidden_column = 'NO'
- """
- if schema is not None:
- params["owner"] = schema
- text += " AND col.owner = :owner "
- text += " ORDER BY col.column_id"
- text = text % {
- "dblink": dblink,
- "char_length_col": char_length_col,
- "identity_cols": identity_cols,
- }
-
- c = connection.execute(sql.text(text), params)
-
- for row in c:
- colname = self.normalize_name(row[0])
- orig_colname = row[0]
- coltype = row[1]
- length = row[2]
- precision = row[3]
- scale = row[4]
- nullable = row[5] == "Y"
- default = row[6]
- comment = row[7]
- generated = row[8]
- default_on_nul = row[9]
- identity_options = row[10]
+ for row_dict in result:
+ table_name = self.normalize_name(row_dict["table_name"])
+ orig_colname = row_dict["column_name"]
+ colname = self.normalize_name(orig_colname)
+ coltype = row_dict["data_type"]
+ precision = row_dict["data_precision"]
if coltype == "NUMBER":
+ scale = row_dict["data_scale"]
if precision is None and scale == 0:
coltype = INTEGER()
else:
@@ -2089,7 +2266,9 @@ class OracleDialect(default.DefaultDialect):
coltype = FLOAT(binary_precision=precision)
elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"):
- coltype = self.ischema_names.get(coltype)(length)
+ coltype = self.ischema_names.get(coltype)(
+ row_dict["char_length"]
+ )
elif "WITH TIME ZONE" in coltype:
coltype = TIMESTAMP(timezone=True)
else:
@@ -2103,15 +2282,17 @@ class OracleDialect(default.DefaultDialect):
)
coltype = sqltypes.NULLTYPE
- if generated == "YES":
+ default = row_dict["data_default"]
+ if row_dict["virtual_column"] == "YES":
computed = dict(sqltext=default)
default = None
else:
computed = None
+ identity_options = row_dict["identity_options"]
if identity_options is not None:
identity = self._parse_identity_options(
- identity_options, default_on_nul
+ identity_options, row_dict["default_on_null"]
)
default = None
else:
@@ -2120,10 +2301,9 @@ class OracleDialect(default.DefaultDialect):
cdict = {
"name": colname,
"type": coltype,
- "nullable": nullable,
+ "nullable": row_dict["nullable"] == "Y",
"default": default,
- "autoincrement": "auto",
- "comment": comment,
+ "comment": row_dict["comments"],
}
if orig_colname.lower() == orig_colname:
cdict["quote"] = True
@@ -2132,10 +2312,17 @@ class OracleDialect(default.DefaultDialect):
if identity is not None:
cdict["identity"] = identity
- columns.append(cdict)
- return columns
+ columns[(schema, table_name)].append(cdict)
- def _parse_identity_options(self, identity_options, default_on_nul):
+ # NOTE: default not needed since all tables have columns
+ # default = ReflectionDefaults.columns
+ # return (
+ # (key, value if value else default())
+ # for key, value in columns.items()
+ # )
+ return columns.items()
+
+ def _parse_identity_options(self, identity_options, default_on_null):
# identity_options is a string that starts with 'ALWAYS,' or
# 'BY DEFAULT,' and continues with
# START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1,
@@ -2144,7 +2331,7 @@ class OracleDialect(default.DefaultDialect):
parts = [p.strip() for p in identity_options.split(",")]
identity = {
"always": parts[0] == "ALWAYS",
- "on_null": default_on_nul == "YES",
+ "on_null": default_on_null == "YES",
}
for part in parts[1:]:
@@ -2168,384 +2355,641 @@ class OracleDialect(default.DefaultDialect):
return identity
@reflection.cache
- def get_table_comment(
+ def get_table_comment(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_table_comment(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ @lru_cache()
+ def _comment_query(self, owner, scope, kind, has_filter_names):
+ # NOTE: all_tab_comments / all_mview_comments have a row for all
+ # object even if they don't have comments
+ queries = []
+ if ObjectKind.TABLE in kind or ObjectKind.VIEW in kind:
+ # all_tab_comments returns also plain views
+ tbl_view = select(
+ dictionary.all_tab_comments.c.table_name,
+ dictionary.all_tab_comments.c.comments,
+ ).where(
+ dictionary.all_tab_comments.c.owner == owner,
+ dictionary.all_tab_comments.c.table_name.not_like("BIN$%"),
+ )
+ if ObjectKind.VIEW not in kind:
+ tbl_view = tbl_view.where(
+ dictionary.all_tab_comments.c.table_type == "TABLE"
+ )
+ elif ObjectKind.TABLE not in kind:
+ tbl_view = tbl_view.where(
+ dictionary.all_tab_comments.c.table_type == "VIEW"
+ )
+ queries.append(tbl_view)
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ mat_view = select(
+ dictionary.all_mview_comments.c.mview_name.label("table_name"),
+ dictionary.all_mview_comments.c.comments,
+ ).where(
+ dictionary.all_mview_comments.c.owner == owner,
+ dictionary.all_mview_comments.c.mview_name.not_like("BIN$%"),
+ )
+ queries.append(mat_view)
+ if len(queries) == 1:
+ query = queries[0]
+ else:
+ union = sql.union_all(*queries).subquery("tables_and_views")
+ query = select(union.c.table_name, union.c.comments)
+
+ name_col = query.selected_columns.table_name
+
+ if scope in (ObjectScope.DEFAULT, ObjectScope.TEMPORARY):
+ temp = "Y" if scope is ObjectScope.TEMPORARY else "N"
+ # need distinct since materialized view are listed also
+ # as tables in all_objects
+ query = query.distinct().join(
+ dictionary.all_objects,
+ and_(
+ dictionary.all_objects.c.owner == owner,
+ dictionary.all_objects.c.object_name == name_col,
+ dictionary.all_objects.c.temporary == temp,
+ ),
+ )
+ if has_filter_names:
+ query = query.where(name_col.in_(bindparam("filter_names")))
+ return query
+
+ @_handle_synonyms_decorator
+ def get_multi_table_comment(
self,
connection,
- table_name,
- schema=None,
- resolve_synonyms=False,
- dblink="",
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
**kw,
):
-
- info_cache = kw.get("info_cache")
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
- connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
- )
-
- if not schema:
- schema = self.default_schema_name
-
- COMMENT_SQL = """
- SELECT comments
- FROM all_tab_comments
- WHERE table_name = CAST(:table_name AS VARCHAR(128))
- AND owner = CAST(:schema_name AS VARCHAR(128))
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._comment_query(owner, scope, kind, has_filter_names)
- c = connection.execute(
- sql.text(COMMENT_SQL),
- dict(table_name=table_name, schema_name=schema),
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False, params=params
+ )
+ default = ReflectionDefaults.table_comment
+ # materialized views by default seem to have a comment like
+ # "snapshot table for snapshot owner.mat_view_name"
+ ignore_mat_view = "snapshot table for snapshot "
+ return (
+ (
+ (schema, self.normalize_name(table)),
+ {"text": comment}
+ if comment is not None
+ and not comment.startswith(ignore_mat_view)
+ else default(),
+ )
+ for table, comment in result
)
- return {"text": c.scalar()}
@reflection.cache
- def get_indexes(
- self,
- connection,
- table_name,
- schema=None,
- resolve_synonyms=False,
- dblink="",
- **kw,
- ):
-
- info_cache = kw.get("info_cache")
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_indexes(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
- indexes = []
-
- params = {"table_name": table_name}
- text = (
- "SELECT a.index_name, a.column_name, "
- "\nb.index_type, b.uniqueness, b.compression, b.prefix_length "
- "\nFROM ALL_IND_COLUMNS%(dblink)s a, "
- "\nALL_INDEXES%(dblink)s b "
- "\nWHERE "
- "\na.index_name = b.index_name "
- "\nAND a.table_owner = b.table_owner "
- "\nAND a.table_name = b.table_name "
- "\nAND a.table_name = CAST(:table_name AS VARCHAR(128))"
+ return self._value_or_raise(data, table_name, schema)
+
+ @lru_cache()
+ def _index_query(self, owner):
+ return (
+ select(
+ dictionary.all_ind_columns.c.table_name,
+ dictionary.all_ind_columns.c.index_name,
+ dictionary.all_ind_columns.c.column_name,
+ dictionary.all_indexes.c.index_type,
+ dictionary.all_indexes.c.uniqueness,
+ dictionary.all_indexes.c.compression,
+ dictionary.all_indexes.c.prefix_length,
+ )
+ .select_from(dictionary.all_ind_columns)
+ .join(
+ dictionary.all_indexes,
+ sql.and_(
+ dictionary.all_ind_columns.c.index_name
+ == dictionary.all_indexes.c.index_name,
+ dictionary.all_ind_columns.c.table_owner
+ == dictionary.all_indexes.c.table_owner,
+ # NOTE: this condition on table_name is not required
+ # but it improves the query performance noticeably
+ dictionary.all_ind_columns.c.table_name
+ == dictionary.all_indexes.c.table_name,
+ ),
+ )
+ .where(
+ dictionary.all_ind_columns.c.table_owner == owner,
+ dictionary.all_ind_columns.c.table_name.in_(
+ bindparam("all_objects")
+ ),
+ )
+ .order_by(
+ dictionary.all_ind_columns.c.index_name,
+ dictionary.all_ind_columns.c.column_position,
+ )
)
- if schema is not None:
- params["schema"] = schema
- text += "AND a.table_owner = :schema "
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("dblink", InternalTraversal.dp_string),
+ ("all_objects", InternalTraversal.dp_string_list),
+ )
+ def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw):
+ owner = self.denormalize_name(schema or self.default_schema_name)
- text += "ORDER BY a.index_name, a.column_position"
+ query = self._index_query(owner)
- text = text % {"dblink": dblink}
+ pks = {
+ row_dict["constraint_name"]
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ )
+ if row_dict["constraint_type"] == "P"
+ }
- q = sql.text(text)
- rp = connection.execute(q, params)
- indexes = []
- last_index_name = None
- pk_constraint = self.get_pk_constraint(
+ result = self._run_batches(
connection,
- table_name,
- schema,
- resolve_synonyms=resolve_synonyms,
- dblink=dblink,
- info_cache=kw.get("info_cache"),
+ query,
+ dblink,
+ returns_long=False,
+ mappings=True,
+ all_objects=all_objects,
)
- uniqueness = dict(NONUNIQUE=False, UNIQUE=True)
- enabled = dict(DISABLED=False, ENABLED=True)
+ return [
+ row_dict
+ for row_dict in result
+ if row_dict["index_name"] not in pks
+ ]
- oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE)
+ @_handle_synonyms_decorator
+ def get_multi_indexes(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
+ )
- index = None
- for rset in rp:
- index_name_normalized = self.normalize_name(rset.index_name)
+ uniqueness = {"NONUNIQUE": False, "UNIQUE": True}
+ enabled = {"DISABLED": False, "ENABLED": True}
+ is_bitmap = {"BITMAP", "FUNCTION-BASED BITMAP"}
- # skip primary key index. This is refined as of
- # [ticket:5421]. Note that ALL_INDEXES.GENERATED will by "Y"
- # if the name of this index was generated by Oracle, however
- # if a named primary key constraint was created then this flag
- # is false.
- if (
- pk_constraint
- and index_name_normalized == pk_constraint["name"]
- ):
- continue
+ oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE)
- if rset.index_name != last_index_name:
- index = dict(
- name=index_name_normalized,
- column_names=[],
- dialect_options={},
- )
- indexes.append(index)
- index["unique"] = uniqueness.get(rset.uniqueness, False)
+ indexes = defaultdict(dict)
+
+ for row_dict in self._get_indexes_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ index_name = self.normalize_name(row_dict["index_name"])
+ table_name = self.normalize_name(row_dict["table_name"])
+ table_indexes = indexes[(schema, table_name)]
+
+ if index_name not in table_indexes:
+ table_indexes[index_name] = index_dict = {
+ "name": index_name,
+ "column_names": [],
+ "dialect_options": {},
+ "unique": uniqueness.get(row_dict["uniqueness"], False),
+ }
+ do = index_dict["dialect_options"]
+ if row_dict["index_type"] in is_bitmap:
+ do["oracle_bitmap"] = True
+ if enabled.get(row_dict["compression"], False):
+ do["oracle_compress"] = row_dict["prefix_length"]
- if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"):
- index["dialect_options"]["oracle_bitmap"] = True
- if enabled.get(rset.compression, False):
- index["dialect_options"][
- "oracle_compress"
- ] = rset.prefix_length
+ else:
+ index_dict = table_indexes[index_name]
# filter out Oracle SYS_NC names. could also do an outer join
- # to the all_tab_columns table and check for real col names there.
- if not oracle_sys_col.match(rset.column_name):
- index["column_names"].append(
- self.normalize_name(rset.column_name)
+ # to the all_tab_columns table and check for real col names
+ # there.
+ if not oracle_sys_col.match(row_dict["column_name"]):
+ index_dict["column_names"].append(
+ self.normalize_name(row_dict["column_name"])
)
- last_index_name = rset.index_name
- return indexes
+ default = ReflectionDefaults.indexes
- @reflection.cache
- def _get_constraint_data(
- self, connection, table_name, schema=None, dblink="", **kw
- ):
-
- params = {"table_name": table_name}
-
- text = (
- "SELECT"
- "\nac.constraint_name," # 0
- "\nac.constraint_type," # 1
- "\nloc.column_name AS local_column," # 2
- "\nrem.table_name AS remote_table," # 3
- "\nrem.column_name AS remote_column," # 4
- "\nrem.owner AS remote_owner," # 5
- "\nloc.position as loc_pos," # 6
- "\nrem.position as rem_pos," # 7
- "\nac.search_condition," # 8
- "\nac.delete_rule" # 9
- "\nFROM all_constraints%(dblink)s ac,"
- "\nall_cons_columns%(dblink)s loc,"
- "\nall_cons_columns%(dblink)s rem"
- "\nWHERE ac.table_name = CAST(:table_name AS VARCHAR2(128))"
- "\nAND ac.constraint_type IN ('R','P', 'U', 'C')"
- )
-
- if schema is not None:
- params["owner"] = schema
- text += "\nAND ac.owner = CAST(:owner AS VARCHAR2(128))"
-
- text += (
- "\nAND ac.owner = loc.owner"
- "\nAND ac.constraint_name = loc.constraint_name"
- "\nAND ac.r_owner = rem.owner(+)"
- "\nAND ac.r_constraint_name = rem.constraint_name(+)"
- "\nAND (rem.position IS NULL or loc.position=rem.position)"
- "\nORDER BY ac.constraint_name, loc.position"
+ return (
+ (key, list(indexes[key].values()) if key in indexes else default())
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
)
- text = text % {"dblink": dblink}
- rp = connection.execute(sql.text(text), params)
- constraint_data = rp.fetchall()
- return constraint_data
-
@reflection.cache
def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_pk_constraint(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
- pkeys = []
- constraint_name = None
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ return self._value_or_raise(data, table_name, schema)
+
+ @lru_cache()
+ def _constraint_query(self, owner):
+ local = dictionary.all_cons_columns.alias("local")
+ remote = dictionary.all_cons_columns.alias("remote")
+ return (
+ select(
+ dictionary.all_constraints.c.table_name,
+ dictionary.all_constraints.c.constraint_type,
+ dictionary.all_constraints.c.constraint_name,
+ local.c.column_name.label("local_column"),
+ remote.c.table_name.label("remote_table"),
+ remote.c.column_name.label("remote_column"),
+ remote.c.owner.label("remote_owner"),
+ dictionary.all_constraints.c.search_condition,
+ dictionary.all_constraints.c.delete_rule,
+ )
+ .select_from(dictionary.all_constraints)
+ .join(
+ local,
+ and_(
+ local.c.owner == dictionary.all_constraints.c.owner,
+ dictionary.all_constraints.c.constraint_name
+ == local.c.constraint_name,
+ ),
+ )
+ .outerjoin(
+ remote,
+ and_(
+ dictionary.all_constraints.c.r_owner == remote.c.owner,
+ dictionary.all_constraints.c.r_constraint_name
+ == remote.c.constraint_name,
+ or_(
+ remote.c.position.is_(sql.null()),
+ local.c.position == remote.c.position,
+ ),
+ ),
+ )
+ .where(
+ dictionary.all_constraints.c.owner == owner,
+ dictionary.all_constraints.c.table_name.in_(
+ bindparam("all_objects")
+ ),
+ dictionary.all_constraints.c.constraint_type.in_(
+ ("R", "P", "U", "C")
+ ),
+ )
+ .order_by(
+ dictionary.all_constraints.c.constraint_name, local.c.position
+ )
)
- for row in constraint_data:
- (
- cons_name,
- cons_type,
- local_column,
- remote_table,
- remote_column,
- remote_owner,
- ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
- if cons_type == "P":
- if constraint_name is None:
- constraint_name = self.normalize_name(cons_name)
- pkeys.append(local_column)
- return {"constrained_columns": pkeys, "name": constraint_name}
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("dblink", InternalTraversal.dp_string),
+ ("all_objects", InternalTraversal.dp_string_list),
+ )
+ def _get_all_constraint_rows(
+ self, connection, schema, dblink, all_objects, **kw
+ ):
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ query = self._constraint_query(owner)
- @reflection.cache
- def get_foreign_keys(self, connection, table_name, schema=None, **kw):
+ # since the result is cached a list must be created
+ values = list(
+ self._run_batches(
+ connection,
+ query,
+ dblink,
+ returns_long=False,
+ mappings=True,
+ all_objects=all_objects,
+ )
+ )
+ return values
+
+ @_handle_synonyms_decorator
+ def get_multi_pk_constraint(
+ self,
+ connection,
+ *,
+ scope,
+ schema,
+ filter_names,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
+ )
- kw arguments can be:
+ primary_keys = defaultdict(dict)
+ default = ReflectionDefaults.pk_constraint
- oracle_resolve_synonyms
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "P":
+ continue
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name = self.normalize_name(row_dict["constraint_name"])
+ column_name = self.normalize_name(row_dict["local_column"])
+
+ table_pk = primary_keys[(schema, table_name)]
+ if not table_pk:
+ table_pk["name"] = constraint_name
+ table_pk["constrained_columns"] = [column_name]
+ else:
+ table_pk["constrained_columns"].append(column_name)
- dblink
+ return (
+ (key, primary_keys[key] if key in primary_keys else default())
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
+ @reflection.cache
+ def get_foreign_keys(
+ self,
+ connection,
+ table_name,
+ schema=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
"""
- requested_schema = schema # to check later on
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ data = self.get_multi_foreign_keys(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ @_handle_synonyms_decorator
+ def get_multi_foreign_keys(
+ self,
+ connection,
+ *,
+ scope,
+ schema,
+ filter_names,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
)
- def fkey_rec():
- return {
- "name": None,
- "constrained_columns": [],
- "referred_schema": None,
- "referred_table": None,
- "referred_columns": [],
- "options": {},
- }
+ resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- fkeys = util.defaultdict(fkey_rec)
+ owner = self.denormalize_name(schema or self.default_schema_name)
- for row in constraint_data:
- (
- cons_name,
- cons_type,
- local_column,
- remote_table,
- remote_column,
- remote_owner,
- ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]])
-
- cons_name = self.normalize_name(cons_name)
-
- if cons_type == "R":
- if remote_table is None:
- # ticket 363
- util.warn(
- (
- "Got 'None' querying 'table_name' from "
- "all_cons_columns%(dblink)s - does the user have "
- "proper rights to the table?"
- )
- % {"dblink": dblink}
- )
- continue
+ all_remote_owners = set()
+ fkeys = defaultdict(dict)
+
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "R":
+ continue
+
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name = self.normalize_name(row_dict["constraint_name"])
+ table_fkey = fkeys[(schema, table_name)]
+
+ assert constraint_name is not None
- rec = fkeys[cons_name]
- rec["name"] = cons_name
- local_cols, remote_cols = (
- rec["constrained_columns"],
- rec["referred_columns"],
+ local_column = self.normalize_name(row_dict["local_column"])
+ remote_table = self.normalize_name(row_dict["remote_table"])
+ remote_column = self.normalize_name(row_dict["remote_column"])
+ remote_owner_orig = row_dict["remote_owner"]
+ remote_owner = self.normalize_name(remote_owner_orig)
+ if remote_owner_orig is not None:
+ all_remote_owners.add(remote_owner_orig)
+
+ if remote_table is None:
+ # ticket 363
+ if dblink and not dblink.startswith("@"):
+ dblink = f"@{dblink}"
+ util.warn(
+ "Got 'None' querying 'table_name' from "
+ f"all_cons_columns{dblink or ''} - does the user have "
+ "proper rights to the table?"
)
+ continue
- if not rec["referred_table"]:
- if resolve_synonyms:
- (
- ref_remote_name,
- ref_remote_owner,
- ref_dblink,
- ref_synonym,
- ) = self._resolve_synonym(
- connection,
- desired_owner=self.denormalize_name(remote_owner),
- desired_table=self.denormalize_name(remote_table),
- )
- if ref_synonym:
- remote_table = self.normalize_name(ref_synonym)
- remote_owner = self.normalize_name(
- ref_remote_owner
- )
+ if constraint_name not in table_fkey:
+ table_fkey[constraint_name] = fkey = {
+ "name": constraint_name,
+ "constrained_columns": [],
+ "referred_schema": None,
+ "referred_table": remote_table,
+ "referred_columns": [],
+ "options": {},
+ }
- rec["referred_table"] = remote_table
+ if resolve_synonyms:
+ # will be removed below
+ fkey["_ref_schema"] = remote_owner
- if (
- requested_schema is not None
- or self.denormalize_name(remote_owner) != schema
- ):
- rec["referred_schema"] = remote_owner
+ if schema is not None or remote_owner_orig != owner:
+ fkey["referred_schema"] = remote_owner
+
+ delete_rule = row_dict["delete_rule"]
+ if delete_rule != "NO ACTION":
+ fkey["options"]["ondelete"] = delete_rule
+
+ else:
+ fkey = table_fkey[constraint_name]
+
+ fkey["constrained_columns"].append(local_column)
+ fkey["referred_columns"].append(remote_column)
+
+ if resolve_synonyms and all_remote_owners:
+ query = select(
+ dictionary.all_synonyms.c.owner,
+ dictionary.all_synonyms.c.table_name,
+ dictionary.all_synonyms.c.table_owner,
+ dictionary.all_synonyms.c.synonym_name,
+ ).where(dictionary.all_synonyms.c.owner.in_(all_remote_owners))
+
+ result = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).mappings()
- if row[9] != "NO ACTION":
- rec["options"]["ondelete"] = row[9]
+ remote_owners_lut = {}
+ for row in result:
+ synonym_owner = self.normalize_name(row["owner"])
+ table_name = self.normalize_name(row["table_name"])
- local_cols.append(local_column)
- remote_cols.append(remote_column)
+ remote_owners_lut[(synonym_owner, table_name)] = (
+ row["table_owner"],
+ row["synonym_name"],
+ )
+
+ empty = (None, None)
+ for table_fkeys in fkeys.values():
+ for table_fkey in table_fkeys.values():
+ key = (
+ table_fkey.pop("_ref_schema"),
+ table_fkey["referred_table"],
+ )
+ remote_owner, syn_name = remote_owners_lut.get(key, empty)
+ if syn_name:
+ sn = self.normalize_name(syn_name)
+ table_fkey["referred_table"] = sn
+ if schema is not None or remote_owner != owner:
+ ro = self.normalize_name(remote_owner)
+ table_fkey["referred_schema"] = ro
+ else:
+ table_fkey["referred_schema"] = None
+ default = ReflectionDefaults.foreign_keys
- return list(fkeys.values())
+ return (
+ (key, list(fkeys[key].values()) if key in fkeys else default())
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
@reflection.cache
def get_unique_constraints(
self, connection, table_name, schema=None, **kw
):
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_unique_constraints(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ @_handle_synonyms_decorator
+ def get_multi_unique_constraints(
+ self,
+ connection,
+ *,
+ scope,
+ schema,
+ filter_names,
+ kind,
+ dblink=None,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
)
- unique_keys = filter(lambda x: x[1] == "U", constraint_data)
- uniques_group = groupby(unique_keys, lambda x: x[0])
+ unique_cons = defaultdict(dict)
index_names = {
- ix["name"]
- for ix in self.get_indexes(connection, table_name, schema=schema)
+ row_dict["index_name"]
+ for row_dict in self._get_indexes_rows(
+ connection, schema, dblink, all_objects, **kw
+ )
}
- return [
- {
- "name": name,
- "column_names": cols,
- "duplicates_index": name if name in index_names else None,
- }
- for name, cols in [
- [
- self.normalize_name(i[0]),
- [self.normalize_name(x[2]) for x in i[1]],
- ]
- for i in uniques_group
- ]
- ]
+
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "U":
+ continue
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name_orig = row_dict["constraint_name"]
+ constraint_name = self.normalize_name(constraint_name_orig)
+ column_name = self.normalize_name(row_dict["local_column"])
+ table_uc = unique_cons[(schema, table_name)]
+
+ assert constraint_name is not None
+
+ if constraint_name not in table_uc:
+ table_uc[constraint_name] = uc = {
+ "name": constraint_name,
+ "column_names": [],
+ "duplicates_index": constraint_name
+ if constraint_name_orig in index_names
+ else None,
+ }
+ else:
+ uc = table_uc[constraint_name]
+
+ uc["column_names"].append(column_name)
+
+ default = ReflectionDefaults.unique_constraints
+
+ return (
+ (
+ key,
+ list(unique_cons[key].values())
+ if key in unique_cons
+ else default(),
+ )
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
@reflection.cache
def get_view_definition(
@@ -2553,65 +2997,129 @@ class OracleDialect(default.DefaultDialect):
connection,
view_name,
schema=None,
- resolve_synonyms=False,
- dblink="",
+ dblink=None,
**kw,
):
- info_cache = kw.get("info_cache")
- (view_name, schema, dblink, synonym) = self._prepare_reflection_args(
- connection,
- view_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ if kw.get("oracle_resolve_synonyms", False):
+ synonyms = self._get_synonyms(
+ connection, schema, filter_names=[view_name], dblink=dblink
+ )
+ if synonyms:
+ assert len(synonyms) == 1
+ row_dict = synonyms[0]
+ dblink = self.normalize_name(row_dict["db_link"])
+ schema = row_dict["table_owner"]
+ view_name = row_dict["table_name"]
+
+ name = self.denormalize_name(view_name)
+ owner = self.denormalize_name(schema or self.default_schema_name)
+ query = (
+ select(dictionary.all_views.c.text)
+ .where(
+ dictionary.all_views.c.view_name == name,
+ dictionary.all_views.c.owner == owner,
+ )
+ .union_all(
+ select(dictionary.all_mviews.c.query).where(
+ dictionary.all_mviews.c.mview_name == name,
+ dictionary.all_mviews.c.owner == owner,
+ )
+ )
)
- params = {"view_name": view_name}
- text = "SELECT text FROM all_views WHERE view_name=:view_name"
-
- if schema is not None:
- text += " AND owner = :schema"
- params["schema"] = schema
-
- rp = connection.execute(sql.text(text), params).scalar()
- if rp:
- return rp
+ rp = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalar()
+ if rp is None:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
else:
- return None
+ return rp
@reflection.cache
def get_check_constraints(
self, connection, table_name, schema=None, include_all=False, **kw
):
- resolve_synonyms = kw.get("oracle_resolve_synonyms", False)
- dblink = kw.get("dblink", "")
- info_cache = kw.get("info_cache")
-
- (table_name, schema, dblink, synonym) = self._prepare_reflection_args(
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ data = self.get_multi_check_constraints(
connection,
- table_name,
- schema,
- resolve_synonyms,
- dblink,
- info_cache=info_cache,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ include_all=include_all,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- constraint_data = self._get_constraint_data(
- connection,
- table_name,
- schema,
- dblink,
- info_cache=kw.get("info_cache"),
+ @_handle_synonyms_decorator
+ def get_multi_check_constraints(
+ self,
+ connection,
+ *,
+ schema,
+ filter_names,
+ dblink=None,
+ scope,
+ kind,
+ include_all=False,
+ **kw,
+ ):
+ """Supported kw arguments are: ``dblink`` to reflect via a db link;
+ ``oracle_resolve_synonyms`` to resolve names to synonyms
+ """
+ all_objects = self._get_all_objects(
+ connection, schema, scope, kind, filter_names, dblink, **kw
)
- check_constraints = filter(lambda x: x[1] == "C", constraint_data)
+ not_null = re.compile(r"..+?. IS NOT NULL$")
- return [
- {"name": self.normalize_name(cons[0]), "sqltext": cons[8]}
- for cons in check_constraints
- if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8])
- ]
+ check_constraints = defaultdict(list)
+
+ for row_dict in self._get_all_constraint_rows(
+ connection, schema, dblink, all_objects, **kw
+ ):
+ if row_dict["constraint_type"] != "C":
+ continue
+ table_name = self.normalize_name(row_dict["table_name"])
+ constraint_name = self.normalize_name(row_dict["constraint_name"])
+ search_condition = row_dict["search_condition"]
+
+ table_checks = check_constraints[(schema, table_name)]
+ if constraint_name is not None and (
+ include_all or not not_null.match(search_condition)
+ ):
+ table_checks.append(
+ {"name": constraint_name, "sqltext": search_condition}
+ )
+
+ default = ReflectionDefaults.check_constraints
+
+ return (
+ (
+ key,
+ check_constraints[key]
+ if key in check_constraints
+ else default(),
+ )
+ for key in (
+ (schema, self.normalize_name(obj_name))
+ for obj_name in all_objects
+ )
+ )
+
+ def _list_dblinks(self, connection, dblink=None):
+ query = select(dictionary.all_db_links.c.db_link)
+ links = self._execute_reflection(
+ connection, query, dblink, returns_long=False
+ ).scalars()
+ return [self.normalize_name(link) for link in links]
class _OuterJoinColumn(sql.ClauseElement):
diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
index 25e93632c..d2ee0a96e 100644
--- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py
+++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py
@@ -431,6 +431,7 @@ from . import base as oracle
from .base import OracleCompiler
from .base import OracleDialect
from .base import OracleExecutionContext
+from .types import _OracleDateLiteralRender
from ... import exc
from ... import util
from ...engine import cursor as _cursor
@@ -573,7 +574,7 @@ class _CXOracleDate(oracle._OracleDate):
return process
-class _CXOracleTIMESTAMP(oracle._OracleDateLiteralRender, sqltypes.TIMESTAMP):
+class _CXOracleTIMESTAMP(_OracleDateLiteralRender, sqltypes.TIMESTAMP):
def literal_processor(self, dialect):
return self._literal_processor_datetime(dialect)
@@ -812,6 +813,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext):
return None
def pre_exec(self):
+ super().pre_exec()
if not getattr(self.compiled, "_oracle_cx_sql_compiler", False):
return
diff --git a/lib/sqlalchemy/dialects/oracle/dictionary.py b/lib/sqlalchemy/dialects/oracle/dictionary.py
new file mode 100644
index 000000000..ac7a350da
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/dictionary.py
@@ -0,0 +1,495 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from .types import DATE
+from .types import LONG
+from .types import NUMBER
+from .types import RAW
+from .types import VARCHAR2
+from ... import Column
+from ... import MetaData
+from ... import Table
+from ... import table
+from ...sql.sqltypes import CHAR
+
+# constants
+DB_LINK_PLACEHOLDER = "__$sa_dblink$__"
+# tables
+dual = table("dual")
+dictionary_meta = MetaData()
+
+# NOTE: all the dictionary_meta are aliases because oracle does not like
+# using the full table@dblink for every column in query, and complains with
+# ORA-00960: ambiguous column naming in select list
+all_tables = Table(
+ "all_tables" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("tablespace_name", VARCHAR2(30)),
+ Column("cluster_name", VARCHAR2(128)),
+ Column("iot_name", VARCHAR2(128)),
+ Column("status", VARCHAR2(8)),
+ Column("pct_free", NUMBER),
+ Column("pct_used", NUMBER),
+ Column("ini_trans", NUMBER),
+ Column("max_trans", NUMBER),
+ Column("initial_extent", NUMBER),
+ Column("next_extent", NUMBER),
+ Column("min_extents", NUMBER),
+ Column("max_extents", NUMBER),
+ Column("pct_increase", NUMBER),
+ Column("freelists", NUMBER),
+ Column("freelist_groups", NUMBER),
+ Column("logging", VARCHAR2(3)),
+ Column("backed_up", VARCHAR2(1)),
+ Column("num_rows", NUMBER),
+ Column("blocks", NUMBER),
+ Column("empty_blocks", NUMBER),
+ Column("avg_space", NUMBER),
+ Column("chain_cnt", NUMBER),
+ Column("avg_row_len", NUMBER),
+ Column("avg_space_freelist_blocks", NUMBER),
+ Column("num_freelist_blocks", NUMBER),
+ Column("degree", VARCHAR2(10)),
+ Column("instances", VARCHAR2(10)),
+ Column("cache", VARCHAR2(5)),
+ Column("table_lock", VARCHAR2(8)),
+ Column("sample_size", NUMBER),
+ Column("last_analyzed", DATE),
+ Column("partitioned", VARCHAR2(3)),
+ Column("iot_type", VARCHAR2(12)),
+ Column("temporary", VARCHAR2(1)),
+ Column("secondary", VARCHAR2(1)),
+ Column("nested", VARCHAR2(3)),
+ Column("buffer_pool", VARCHAR2(7)),
+ Column("flash_cache", VARCHAR2(7)),
+ Column("cell_flash_cache", VARCHAR2(7)),
+ Column("row_movement", VARCHAR2(8)),
+ Column("global_stats", VARCHAR2(3)),
+ Column("user_stats", VARCHAR2(3)),
+ Column("duration", VARCHAR2(15)),
+ Column("skip_corrupt", VARCHAR2(8)),
+ Column("monitoring", VARCHAR2(3)),
+ Column("cluster_owner", VARCHAR2(128)),
+ Column("dependencies", VARCHAR2(8)),
+ Column("compression", VARCHAR2(8)),
+ Column("compress_for", VARCHAR2(30)),
+ Column("dropped", VARCHAR2(3)),
+ Column("read_only", VARCHAR2(3)),
+ Column("segment_created", VARCHAR2(3)),
+ Column("result_cache", VARCHAR2(7)),
+ Column("clustering", VARCHAR2(3)),
+ Column("activity_tracking", VARCHAR2(23)),
+ Column("dml_timestamp", VARCHAR2(25)),
+ Column("has_identity", VARCHAR2(3)),
+ Column("container_data", VARCHAR2(3)),
+ Column("inmemory", VARCHAR2(8)),
+ Column("inmemory_priority", VARCHAR2(8)),
+ Column("inmemory_distribute", VARCHAR2(15)),
+ Column("inmemory_compression", VARCHAR2(17)),
+ Column("inmemory_duplicate", VARCHAR2(13)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("duplicated", VARCHAR2(1)),
+ Column("sharded", VARCHAR2(1)),
+ Column("externally_sharded", VARCHAR2(1)),
+ Column("externally_duplicated", VARCHAR2(1)),
+ Column("external", VARCHAR2(3)),
+ Column("hybrid", VARCHAR2(3)),
+ Column("cellmemory", VARCHAR2(24)),
+ Column("containers_default", VARCHAR2(3)),
+ Column("container_map", VARCHAR2(3)),
+ Column("extended_data_link", VARCHAR2(3)),
+ Column("extended_data_link_map", VARCHAR2(3)),
+ Column("inmemory_service", VARCHAR2(12)),
+ Column("inmemory_service_name", VARCHAR2(1000)),
+ Column("container_map_object", VARCHAR2(3)),
+ Column("memoptimize_read", VARCHAR2(8)),
+ Column("memoptimize_write", VARCHAR2(8)),
+ Column("has_sensitive_column", VARCHAR2(3)),
+ Column("admit_null", VARCHAR2(3)),
+ Column("data_link_dml_enabled", VARCHAR2(3)),
+ Column("logical_replication", VARCHAR2(8)),
+).alias("a_tables")
+
+all_views = Table(
+ "all_views" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("view_name", VARCHAR2(128), nullable=False),
+ Column("text_length", NUMBER),
+ Column("text", LONG),
+ Column("text_vc", VARCHAR2(4000)),
+ Column("type_text_length", NUMBER),
+ Column("type_text", VARCHAR2(4000)),
+ Column("oid_text_length", NUMBER),
+ Column("oid_text", VARCHAR2(4000)),
+ Column("view_type_owner", VARCHAR2(128)),
+ Column("view_type", VARCHAR2(128)),
+ Column("superview_name", VARCHAR2(128)),
+ Column("editioning_view", VARCHAR2(1)),
+ Column("read_only", VARCHAR2(1)),
+ Column("container_data", VARCHAR2(1)),
+ Column("bequeath", VARCHAR2(12)),
+ Column("origin_con_id", VARCHAR2(256)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("containers_default", VARCHAR2(3)),
+ Column("container_map", VARCHAR2(3)),
+ Column("extended_data_link", VARCHAR2(3)),
+ Column("extended_data_link_map", VARCHAR2(3)),
+ Column("has_sensitive_column", VARCHAR2(3)),
+ Column("admit_null", VARCHAR2(3)),
+ Column("pdb_local_only", VARCHAR2(3)),
+).alias("a_views")
+
+all_sequences = Table(
+ "all_sequences" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("sequence_owner", VARCHAR2(128), nullable=False),
+ Column("sequence_name", VARCHAR2(128), nullable=False),
+ Column("min_value", NUMBER),
+ Column("max_value", NUMBER),
+ Column("increment_by", NUMBER, nullable=False),
+ Column("cycle_flag", VARCHAR2(1)),
+ Column("order_flag", VARCHAR2(1)),
+ Column("cache_size", NUMBER, nullable=False),
+ Column("last_number", NUMBER, nullable=False),
+ Column("scale_flag", VARCHAR2(1)),
+ Column("extend_flag", VARCHAR2(1)),
+ Column("sharded_flag", VARCHAR2(1)),
+ Column("session_flag", VARCHAR2(1)),
+ Column("keep_value", VARCHAR2(1)),
+).alias("a_sequences")
+
+all_users = Table(
+ "all_users" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("username", VARCHAR2(128), nullable=False),
+ Column("user_id", NUMBER, nullable=False),
+ Column("created", DATE, nullable=False),
+ Column("common", VARCHAR2(3)),
+ Column("oracle_maintained", VARCHAR2(1)),
+ Column("inherited", VARCHAR2(3)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("implicit", VARCHAR2(3)),
+ Column("all_shard", VARCHAR2(3)),
+ Column("external_shard", VARCHAR2(3)),
+).alias("a_users")
+
+all_mviews = Table(
+ "all_mviews" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("mview_name", VARCHAR2(128), nullable=False),
+ Column("container_name", VARCHAR2(128), nullable=False),
+ Column("query", LONG),
+ Column("query_len", NUMBER(38)),
+ Column("updatable", VARCHAR2(1)),
+ Column("update_log", VARCHAR2(128)),
+ Column("master_rollback_seg", VARCHAR2(128)),
+ Column("master_link", VARCHAR2(128)),
+ Column("rewrite_enabled", VARCHAR2(1)),
+ Column("rewrite_capability", VARCHAR2(9)),
+ Column("refresh_mode", VARCHAR2(6)),
+ Column("refresh_method", VARCHAR2(8)),
+ Column("build_mode", VARCHAR2(9)),
+ Column("fast_refreshable", VARCHAR2(18)),
+ Column("last_refresh_type", VARCHAR2(8)),
+ Column("last_refresh_date", DATE),
+ Column("last_refresh_end_time", DATE),
+ Column("staleness", VARCHAR2(19)),
+ Column("after_fast_refresh", VARCHAR2(19)),
+ Column("unknown_prebuilt", VARCHAR2(1)),
+ Column("unknown_plsql_func", VARCHAR2(1)),
+ Column("unknown_external_table", VARCHAR2(1)),
+ Column("unknown_consider_fresh", VARCHAR2(1)),
+ Column("unknown_import", VARCHAR2(1)),
+ Column("unknown_trusted_fd", VARCHAR2(1)),
+ Column("compile_state", VARCHAR2(19)),
+ Column("use_no_index", VARCHAR2(1)),
+ Column("stale_since", DATE),
+ Column("num_pct_tables", NUMBER),
+ Column("num_fresh_pct_regions", NUMBER),
+ Column("num_stale_pct_regions", NUMBER),
+ Column("segment_created", VARCHAR2(3)),
+ Column("evaluation_edition", VARCHAR2(128)),
+ Column("unusable_before", VARCHAR2(128)),
+ Column("unusable_beginning", VARCHAR2(128)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("on_query_computation", VARCHAR2(1)),
+ Column("auto", VARCHAR2(3)),
+).alias("a_mviews")
+
+all_tab_identity_cols = Table(
+ "all_tab_identity_cols" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(128), nullable=False),
+ Column("generation_type", VARCHAR2(10)),
+ Column("sequence_name", VARCHAR2(128), nullable=False),
+ Column("identity_options", VARCHAR2(298)),
+).alias("a_tab_identity_cols")
+
+all_tab_cols = Table(
+ "all_tab_cols" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(128), nullable=False),
+ Column("data_type", VARCHAR2(128)),
+ Column("data_type_mod", VARCHAR2(3)),
+ Column("data_type_owner", VARCHAR2(128)),
+ Column("data_length", NUMBER, nullable=False),
+ Column("data_precision", NUMBER),
+ Column("data_scale", NUMBER),
+ Column("nullable", VARCHAR2(1)),
+ Column("column_id", NUMBER),
+ Column("default_length", NUMBER),
+ Column("data_default", LONG),
+ Column("num_distinct", NUMBER),
+ Column("low_value", RAW(1000)),
+ Column("high_value", RAW(1000)),
+ Column("density", NUMBER),
+ Column("num_nulls", NUMBER),
+ Column("num_buckets", NUMBER),
+ Column("last_analyzed", DATE),
+ Column("sample_size", NUMBER),
+ Column("character_set_name", VARCHAR2(44)),
+ Column("char_col_decl_length", NUMBER),
+ Column("global_stats", VARCHAR2(3)),
+ Column("user_stats", VARCHAR2(3)),
+ Column("avg_col_len", NUMBER),
+ Column("char_length", NUMBER),
+ Column("char_used", VARCHAR2(1)),
+ Column("v80_fmt_image", VARCHAR2(3)),
+ Column("data_upgraded", VARCHAR2(3)),
+ Column("hidden_column", VARCHAR2(3)),
+ Column("virtual_column", VARCHAR2(3)),
+ Column("segment_column_id", NUMBER),
+ Column("internal_column_id", NUMBER, nullable=False),
+ Column("histogram", VARCHAR2(15)),
+ Column("qualified_col_name", VARCHAR2(4000)),
+ Column("user_generated", VARCHAR2(3)),
+ Column("default_on_null", VARCHAR2(3)),
+ Column("identity_column", VARCHAR2(3)),
+ Column("evaluation_edition", VARCHAR2(128)),
+ Column("unusable_before", VARCHAR2(128)),
+ Column("unusable_beginning", VARCHAR2(128)),
+ Column("collation", VARCHAR2(100)),
+ Column("collated_column_id", NUMBER),
+).alias("a_tab_cols")
+
+all_tab_comments = Table(
+ "all_tab_comments" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("table_type", VARCHAR2(11)),
+ Column("comments", VARCHAR2(4000)),
+ Column("origin_con_id", NUMBER),
+).alias("a_tab_comments")
+
+all_col_comments = Table(
+ "all_col_comments" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(128), nullable=False),
+ Column("comments", VARCHAR2(4000)),
+ Column("origin_con_id", NUMBER),
+).alias("a_col_comments")
+
+all_mview_comments = Table(
+ "all_mview_comments" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("mview_name", VARCHAR2(128), nullable=False),
+ Column("comments", VARCHAR2(4000)),
+).alias("a_mview_comments")
+
+all_ind_columns = Table(
+ "all_ind_columns" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("index_owner", VARCHAR2(128), nullable=False),
+ Column("index_name", VARCHAR2(128), nullable=False),
+ Column("table_owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(4000)),
+ Column("column_position", NUMBER, nullable=False),
+ Column("column_length", NUMBER, nullable=False),
+ Column("char_length", NUMBER),
+ Column("descend", VARCHAR2(4)),
+ Column("collated_column_id", NUMBER),
+).alias("a_ind_columns")
+
+all_indexes = Table(
+ "all_indexes" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("index_name", VARCHAR2(128), nullable=False),
+ Column("index_type", VARCHAR2(27)),
+ Column("table_owner", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("table_type", CHAR(11)),
+ Column("uniqueness", VARCHAR2(9)),
+ Column("compression", VARCHAR2(13)),
+ Column("prefix_length", NUMBER),
+ Column("tablespace_name", VARCHAR2(30)),
+ Column("ini_trans", NUMBER),
+ Column("max_trans", NUMBER),
+ Column("initial_extent", NUMBER),
+ Column("next_extent", NUMBER),
+ Column("min_extents", NUMBER),
+ Column("max_extents", NUMBER),
+ Column("pct_increase", NUMBER),
+ Column("pct_threshold", NUMBER),
+ Column("include_column", NUMBER),
+ Column("freelists", NUMBER),
+ Column("freelist_groups", NUMBER),
+ Column("pct_free", NUMBER),
+ Column("logging", VARCHAR2(3)),
+ Column("blevel", NUMBER),
+ Column("leaf_blocks", NUMBER),
+ Column("distinct_keys", NUMBER),
+ Column("avg_leaf_blocks_per_key", NUMBER),
+ Column("avg_data_blocks_per_key", NUMBER),
+ Column("clustering_factor", NUMBER),
+ Column("status", VARCHAR2(8)),
+ Column("num_rows", NUMBER),
+ Column("sample_size", NUMBER),
+ Column("last_analyzed", DATE),
+ Column("degree", VARCHAR2(40)),
+ Column("instances", VARCHAR2(40)),
+ Column("partitioned", VARCHAR2(3)),
+ Column("temporary", VARCHAR2(1)),
+ Column("generated", VARCHAR2(1)),
+ Column("secondary", VARCHAR2(1)),
+ Column("buffer_pool", VARCHAR2(7)),
+ Column("flash_cache", VARCHAR2(7)),
+ Column("cell_flash_cache", VARCHAR2(7)),
+ Column("user_stats", VARCHAR2(3)),
+ Column("duration", VARCHAR2(15)),
+ Column("pct_direct_access", NUMBER),
+ Column("ityp_owner", VARCHAR2(128)),
+ Column("ityp_name", VARCHAR2(128)),
+ Column("parameters", VARCHAR2(1000)),
+ Column("global_stats", VARCHAR2(3)),
+ Column("domidx_status", VARCHAR2(12)),
+ Column("domidx_opstatus", VARCHAR2(6)),
+ Column("funcidx_status", VARCHAR2(8)),
+ Column("join_index", VARCHAR2(3)),
+ Column("iot_redundant_pkey_elim", VARCHAR2(3)),
+ Column("dropped", VARCHAR2(3)),
+ Column("visibility", VARCHAR2(9)),
+ Column("domidx_management", VARCHAR2(14)),
+ Column("segment_created", VARCHAR2(3)),
+ Column("orphaned_entries", VARCHAR2(3)),
+ Column("indexing", VARCHAR2(7)),
+ Column("auto", VARCHAR2(3)),
+).alias("a_indexes")
+
+all_constraints = Table(
+ "all_constraints" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128)),
+ Column("constraint_name", VARCHAR2(128)),
+ Column("constraint_type", VARCHAR2(1)),
+ Column("table_name", VARCHAR2(128)),
+ Column("search_condition", LONG),
+ Column("search_condition_vc", VARCHAR2(4000)),
+ Column("r_owner", VARCHAR2(128)),
+ Column("r_constraint_name", VARCHAR2(128)),
+ Column("delete_rule", VARCHAR2(9)),
+ Column("status", VARCHAR2(8)),
+ Column("deferrable", VARCHAR2(14)),
+ Column("deferred", VARCHAR2(9)),
+ Column("validated", VARCHAR2(13)),
+ Column("generated", VARCHAR2(14)),
+ Column("bad", VARCHAR2(3)),
+ Column("rely", VARCHAR2(4)),
+ Column("last_change", DATE),
+ Column("index_owner", VARCHAR2(128)),
+ Column("index_name", VARCHAR2(128)),
+ Column("invalid", VARCHAR2(7)),
+ Column("view_related", VARCHAR2(14)),
+ Column("origin_con_id", VARCHAR2(256)),
+).alias("a_constraints")
+
+all_cons_columns = Table(
+ "all_cons_columns" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("constraint_name", VARCHAR2(128), nullable=False),
+ Column("table_name", VARCHAR2(128), nullable=False),
+ Column("column_name", VARCHAR2(4000)),
+ Column("position", NUMBER),
+).alias("a_cons_columns")
+
+# TODO figure out if it's still relevant, since there is no mention from here
+# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html
+# original note:
+# using user_db_links here since all_db_links appears
+# to have more restricted permissions.
+# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
+# will need to hear from more users if we are doing
+# the right thing here. See [ticket:2619]
+all_db_links = Table(
+ "all_db_links" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("db_link", VARCHAR2(128), nullable=False),
+ Column("username", VARCHAR2(128)),
+ Column("host", VARCHAR2(2000)),
+ Column("created", DATE, nullable=False),
+ Column("hidden", VARCHAR2(3)),
+ Column("shard_internal", VARCHAR2(3)),
+ Column("valid", VARCHAR2(3)),
+ Column("intra_cdb", VARCHAR2(3)),
+).alias("a_db_links")
+
+all_synonyms = Table(
+ "all_synonyms" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128)),
+ Column("synonym_name", VARCHAR2(128)),
+ Column("table_owner", VARCHAR2(128)),
+ Column("table_name", VARCHAR2(128)),
+ Column("db_link", VARCHAR2(128)),
+ Column("origin_con_id", VARCHAR2(256)),
+).alias("a_synonyms")
+
+all_objects = Table(
+ "all_objects" + DB_LINK_PLACEHOLDER,
+ dictionary_meta,
+ Column("owner", VARCHAR2(128), nullable=False),
+ Column("object_name", VARCHAR2(128), nullable=False),
+ Column("subobject_name", VARCHAR2(128)),
+ Column("object_id", NUMBER, nullable=False),
+ Column("data_object_id", NUMBER),
+ Column("object_type", VARCHAR2(23)),
+ Column("created", DATE, nullable=False),
+ Column("last_ddl_time", DATE, nullable=False),
+ Column("timestamp", VARCHAR2(19)),
+ Column("status", VARCHAR2(7)),
+ Column("temporary", VARCHAR2(1)),
+ Column("generated", VARCHAR2(1)),
+ Column("secondary", VARCHAR2(1)),
+ Column("namespace", NUMBER, nullable=False),
+ Column("edition_name", VARCHAR2(128)),
+ Column("sharing", VARCHAR2(13)),
+ Column("editionable", VARCHAR2(1)),
+ Column("oracle_maintained", VARCHAR2(1)),
+ Column("application", VARCHAR2(1)),
+ Column("default_collation", VARCHAR2(100)),
+ Column("duplicated", VARCHAR2(1)),
+ Column("sharded", VARCHAR2(1)),
+ Column("created_appid", NUMBER),
+ Column("created_vsnid", NUMBER),
+ Column("modified_appid", NUMBER),
+ Column("modified_vsnid", NUMBER),
+).alias("a_objects")
diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py
index cba3b5be4..75b7a7aa9 100644
--- a/lib/sqlalchemy/dialects/oracle/provision.py
+++ b/lib/sqlalchemy/dialects/oracle/provision.py
@@ -2,9 +2,12 @@
from ... import create_engine
from ... import exc
+from ... import inspect
from ...engine import url as sa_url
from ...testing.provision import configure_follower
from ...testing.provision import create_db
+from ...testing.provision import drop_all_schema_objects_post_tables
+from ...testing.provision import drop_all_schema_objects_pre_tables
from ...testing.provision import drop_db
from ...testing.provision import follower_url_from_main
from ...testing.provision import log
@@ -28,6 +31,10 @@ def _oracle_create_db(cfg, eng, ident):
conn.exec_driver_sql("grant unlimited tablespace to %s" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident)
conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident)
+ # these are needed to create materialized views
+ conn.exec_driver_sql("grant create table to %s" % ident)
+ conn.exec_driver_sql("grant create table to %s_ts1" % ident)
+ conn.exec_driver_sql("grant create table to %s_ts2" % ident)
@configure_follower.for_db("oracle")
@@ -46,6 +53,30 @@ def _ora_drop_ignore(conn, dbname):
return False
+@drop_all_schema_objects_pre_tables.for_db("oracle")
+def _ora_drop_all_schema_objects_pre_tables(cfg, eng):
+ _purge_recyclebin(eng)
+ _purge_recyclebin(eng, cfg.test_schema)
+
+
+@drop_all_schema_objects_post_tables.for_db("oracle")
+def _ora_drop_all_schema_objects_post_tables(cfg, eng):
+
+ with eng.begin() as conn:
+ for syn in conn.dialect._get_synonyms(conn, None, None, None):
+ conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}")
+
+ for syn in conn.dialect._get_synonyms(
+ conn, cfg.test_schema, None, None
+ ):
+ conn.exec_driver_sql(
+ f"drop synonym {cfg.test_schema}.{syn['synonym_name']}"
+ )
+
+ for tmp_table in inspect(conn).get_temp_table_names():
+ conn.exec_driver_sql(f"drop table {tmp_table}")
+
+
@drop_db.for_db("oracle")
def _oracle_drop_db(cfg, eng, ident):
with eng.begin() as conn:
@@ -60,13 +91,10 @@ def _oracle_drop_db(cfg, eng, ident):
@stop_test_class_outside_fixtures.for_db("oracle")
-def stop_test_class_outside_fixtures(config, db, cls):
+def _ora_stop_test_class_outside_fixtures(config, db, cls):
try:
- with db.begin() as conn:
- # run magic command to get rid of identity sequences
- # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
- conn.exec_driver_sql("purge recyclebin")
+ _purge_recyclebin(db)
except exc.DatabaseError as err:
log.warning("purge recyclebin command failed: %s", err)
@@ -85,6 +113,22 @@ def stop_test_class_outside_fixtures(config, db, cls):
_all_conns.clear()
+def _purge_recyclebin(eng, schema=None):
+ with eng.begin() as conn:
+ if schema is None:
+ # run magic command to get rid of identity sequences
+ # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
+ conn.exec_driver_sql("purge recyclebin")
+ else:
+ # per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501
+ for owner, object_name, type_ in conn.exec_driver_sql(
+ "select owner, object_name,type from "
+ "dba_recyclebin where owner=:schema and type='TABLE'",
+ {"schema": conn.dialect.denormalize_name(schema)},
+ ).all():
+ conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"')
+
+
_all_conns = set()
diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py
new file mode 100644
index 000000000..60a8ebcb5
--- /dev/null
+++ b/lib/sqlalchemy/dialects/oracle/types.py
@@ -0,0 +1,233 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from ...sql import sqltypes
+from ...types import NVARCHAR
+from ...types import VARCHAR
+
+
+class RAW(sqltypes._Binary):
+ __visit_name__ = "RAW"
+
+
+OracleRaw = RAW
+
+
+class NCLOB(sqltypes.Text):
+ __visit_name__ = "NCLOB"
+
+
+class VARCHAR2(VARCHAR):
+ __visit_name__ = "VARCHAR2"
+
+
+NVARCHAR2 = NVARCHAR
+
+
+class NUMBER(sqltypes.Numeric, sqltypes.Integer):
+ __visit_name__ = "NUMBER"
+
+ def __init__(self, precision=None, scale=None, asdecimal=None):
+ if asdecimal is None:
+ asdecimal = bool(scale and scale > 0)
+
+ super(NUMBER, self).__init__(
+ precision=precision, scale=scale, asdecimal=asdecimal
+ )
+
+ def adapt(self, impltype):
+ ret = super(NUMBER, self).adapt(impltype)
+ # leave a hint for the DBAPI handler
+ ret._is_oracle_number = True
+ return ret
+
+ @property
+ def _type_affinity(self):
+ if bool(self.scale and self.scale > 0):
+ return sqltypes.Numeric
+ else:
+ return sqltypes.Integer
+
+
+class FLOAT(sqltypes.FLOAT):
+ """Oracle FLOAT.
+
+ This is the same as :class:`_sqltypes.FLOAT` except that
+ an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
+ parameter is accepted, and
+ the :paramref:`_sqltypes.Float.precision` parameter is not accepted.
+
+ Oracle FLOAT types indicate precision in terms of "binary precision", which
+ defaults to 126. For a REAL type, the value is 63. This parameter does not
+ cleanly map to a specific number of decimal places but is roughly
+ equivalent to the desired number of decimal places divided by 0.3103.
+
+ .. versionadded:: 2.0
+
+ """
+
+ __visit_name__ = "FLOAT"
+
+ def __init__(
+ self,
+ binary_precision=None,
+ asdecimal=False,
+ decimal_return_scale=None,
+ ):
+ r"""
+ Construct a FLOAT
+
+ :param binary_precision: Oracle binary precision value to be rendered
+ in DDL. This may be approximated to the number of decimal characters
+ using the formula "decimal precision = 0.30103 * binary precision".
+ The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.
+
+ :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`
+
+ :param decimal_return_scale: See
+ :paramref:`_sqltypes.Float.decimal_return_scale`
+
+ """
+ super().__init__(
+ asdecimal=asdecimal, decimal_return_scale=decimal_return_scale
+ )
+ self.binary_precision = binary_precision
+
+
+class BINARY_DOUBLE(sqltypes.Float):
+ __visit_name__ = "BINARY_DOUBLE"
+
+
+class BINARY_FLOAT(sqltypes.Float):
+ __visit_name__ = "BINARY_FLOAT"
+
+
+class BFILE(sqltypes.LargeBinary):
+ __visit_name__ = "BFILE"
+
+
+class LONG(sqltypes.Text):
+ __visit_name__ = "LONG"
+
+
+class _OracleDateLiteralRender:
+ def _literal_processor_datetime(self, dialect):
+ def process(value):
+ if value is not None:
+ if getattr(value, "microsecond", None):
+ value = (
+ f"""TO_TIMESTAMP"""
+ f"""('{value.isoformat().replace("T", " ")}', """
+ """'YYYY-MM-DD HH24:MI:SS.FF')"""
+ )
+ else:
+ value = (
+ f"""TO_DATE"""
+ f"""('{value.isoformat().replace("T", " ")}', """
+ """'YYYY-MM-DD HH24:MI:SS')"""
+ )
+ return value
+
+ return process
+
+ def _literal_processor_date(self, dialect):
+ def process(value):
+ if value is not None:
+ if getattr(value, "microsecond", None):
+ value = (
+ f"""TO_TIMESTAMP"""
+ f"""('{value.isoformat().split("T")[0]}', """
+ """'YYYY-MM-DD')"""
+ )
+ else:
+ value = (
+ f"""TO_DATE"""
+ f"""('{value.isoformat().split("T")[0]}', """
+ """'YYYY-MM-DD')"""
+ )
+ return value
+
+ return process
+
+
+class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
+ """Provide the oracle DATE type.
+
+ This type has no special Python behavior, except that it subclasses
+ :class:`_types.DateTime`; this is to suit the fact that the Oracle
+ ``DATE`` type supports a time value.
+
+ .. versionadded:: 0.9.4
+
+ """
+
+ __visit_name__ = "DATE"
+
+ def literal_processor(self, dialect):
+ return self._literal_processor_datetime(dialect)
+
+ def _compare_type_affinity(self, other):
+ return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
+
+
+class _OracleDate(_OracleDateLiteralRender, sqltypes.Date):
+ def literal_processor(self, dialect):
+ return self._literal_processor_date(dialect)
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+ __visit_name__ = "INTERVAL"
+
+ def __init__(self, day_precision=None, second_precision=None):
+ """Construct an INTERVAL.
+
+ Note that only DAY TO SECOND intervals are currently supported.
+ This is due to a lack of support for YEAR TO MONTH intervals
+ within available DBAPIs.
+
+ :param day_precision: the day precision value. this is the number of
+ digits to store for the day field. Defaults to "2"
+ :param second_precision: the second precision value. this is the
+ number of digits to store for the fractional seconds field.
+ Defaults to "6".
+
+ """
+ self.day_precision = day_precision
+ self.second_precision = second_precision
+
+ @classmethod
+ def _adapt_from_generic_interval(cls, interval):
+ return INTERVAL(
+ day_precision=interval.day_precision,
+ second_precision=interval.second_precision,
+ )
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(
+ native=True,
+ second_precision=self.second_precision,
+ day_precision=self.day_precision,
+ )
+
+
+class ROWID(sqltypes.TypeEngine):
+ """Oracle ROWID type.
+
+ When used in a cast() or similar, generates ROWID.
+
+ """
+
+ __visit_name__ = "ROWID"
+
+
+class _OracleBoolean(sqltypes.Boolean):
+ def get_dbapi_type(self, dbapi):
+ return dbapi.NUMBER
diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py
index c2472fb55..85bbf8c5b 100644
--- a/lib/sqlalchemy/dialects/postgresql/__init__.py
+++ b/lib/sqlalchemy/dialects/postgresql/__init__.py
@@ -19,31 +19,16 @@ from .array import Any
from .array import ARRAY
from .array import array
from .base import BIGINT
-from .base import BIT
from .base import BOOLEAN
-from .base import BYTEA
from .base import CHAR
-from .base import CIDR
-from .base import CreateEnumType
from .base import DATE
from .base import DOUBLE_PRECISION
-from .base import DropEnumType
-from .base import ENUM
from .base import FLOAT
-from .base import INET
from .base import INTEGER
-from .base import INTERVAL
-from .base import MACADDR
-from .base import MONEY
from .base import NUMERIC
-from .base import OID
from .base import REAL
-from .base import REGCLASS
from .base import SMALLINT
from .base import TEXT
-from .base import TIME
-from .base import TIMESTAMP
-from .base import TSVECTOR
from .base import UUID
from .base import VARCHAR
from .dml import Insert
@@ -61,7 +46,21 @@ from .ranges import INT8RANGE
from .ranges import NUMRANGE
from .ranges import TSRANGE
from .ranges import TSTZRANGE
-from ...util import compat
+from .types import BIT
+from .types import BYTEA
+from .types import CIDR
+from .types import CreateEnumType
+from .types import DropEnumType
+from .types import ENUM
+from .types import INET
+from .types import INTERVAL
+from .types import MACADDR
+from .types import MONEY
+from .types import OID
+from .types import REGCLASS
+from .types import TIME
+from .types import TIMESTAMP
+from .types import TSVECTOR
# Alias psycopg also as psycopg_async
psycopg_async = type(
diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
index e831f2ed9..8dcd36c6d 100644
--- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
+++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py
@@ -1,3 +1,8 @@
+# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors
import decimal
@@ -9,6 +14,9 @@ from .base import _INT_TYPES
from .base import PGDialect
from .base import PGExecutionContext
from .hstore import HSTORE
+from .pg_catalog import _SpaceVector
+from .pg_catalog import INT2VECTOR
+from .pg_catalog import OIDVECTOR
from ... import exc
from ... import types as sqltypes
from ... import util
@@ -66,6 +74,14 @@ class _PsycopgARRAY(PGARRAY):
render_bind_cast = True
+class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR):
+ pass
+
+
+class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR):
+ pass
+
+
class _PGExecutionContext_common_psycopg(PGExecutionContext):
def create_server_side_cursor(self):
# use server-side cursors:
@@ -91,6 +107,8 @@ class _PGDialect_common_psycopg(PGDialect):
sqltypes.Numeric: _PsycopgNumeric,
HSTORE: _PsycopgHStore,
sqltypes.ARRAY: _PsycopgARRAY,
+ INT2VECTOR: _PsycopgINT2VECTOR,
+ OIDVECTOR: _PsycopgOIDVECTOR,
},
)
diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
index 1ec787e1f..d6385a5d6 100644
--- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py
+++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py
@@ -274,6 +274,10 @@ class AsyncpgOID(OID):
render_bind_cast = True
+class AsyncpgCHAR(sqltypes.CHAR):
+ render_bind_cast = True
+
+
class PGExecutionContext_asyncpg(PGExecutionContext):
def handle_dbapi_exception(self, e):
if isinstance(
@@ -823,6 +827,7 @@ class PGDialect_asyncpg(PGDialect):
sqltypes.Enum: AsyncPgEnum,
OID: AsyncpgOID,
REGCLASS: AsyncpgREGCLASS,
+ sqltypes.CHAR: AsyncpgCHAR,
},
)
is_async = True
diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py
index 83e46151f..36de76e0d 100644
--- a/lib/sqlalchemy/dialects/postgresql/base.py
+++ b/lib/sqlalchemy/dialects/postgresql/base.py
@@ -11,7 +11,7 @@ r"""
:name: PostgreSQL
:full_support: 9.6, 10, 11, 12, 13, 14
:normal_support: 9.6+
- :best_effort: 8+
+ :best_effort: 9+
.. _postgresql_sequences:
@@ -1448,23 +1448,52 @@ E.g.::
from __future__ import annotations
from collections import defaultdict
-import datetime as dt
+from functools import lru_cache
import re
-from typing import Any
from . import array as _array
from . import dml
from . import hstore as _hstore
from . import json as _json
+from . import pg_catalog
from . import ranges as _ranges
+from .types import _DECIMAL_TYPES # noqa
+from .types import _FLOAT_TYPES # noqa
+from .types import _INT_TYPES # noqa
+from .types import BIT
+from .types import BYTEA
+from .types import CIDR
+from .types import CreateEnumType # noqa
+from .types import DropEnumType # noqa
+from .types import ENUM
+from .types import INET
+from .types import INTERVAL
+from .types import MACADDR
+from .types import MONEY
+from .types import OID
+from .types import PGBit # noqa
+from .types import PGCidr # noqa
+from .types import PGInet # noqa
+from .types import PGInterval # noqa
+from .types import PGMacAddr # noqa
+from .types import PGUuid
+from .types import REGCLASS
+from .types import TIME
+from .types import TIMESTAMP
+from .types import TSVECTOR
from ... import exc
from ... import schema
+from ... import select
from ... import sql
from ... import util
from ...engine import characteristics
from ...engine import default
from ...engine import interfaces
+from ...engine import ObjectKind
+from ...engine import ObjectScope
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
+from ...sql import bindparam
from ...sql import coercions
from ...sql import compiler
from ...sql import elements
@@ -1472,7 +1501,7 @@ from ...sql import expression
from ...sql import roles
from ...sql import sqltypes
from ...sql import util as sql_util
-from ...sql.ddl import InvokeDDLBase
+from ...sql.visitors import InternalTraversal
from ...types import BIGINT
from ...types import BOOLEAN
from ...types import CHAR
@@ -1596,469 +1625,6 @@ RESERVED_WORDS = set(
]
)
-_DECIMAL_TYPES = (1231, 1700)
-_FLOAT_TYPES = (700, 701, 1021, 1022)
-_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
-
-
-class PGUuid(UUID):
- render_bind_cast = True
- render_literal_cast = True
-
-
-class BYTEA(sqltypes.LargeBinary[bytes]):
- __visit_name__ = "BYTEA"
-
-
-class INET(sqltypes.TypeEngine[str]):
- __visit_name__ = "INET"
-
-
-PGInet = INET
-
-
-class CIDR(sqltypes.TypeEngine[str]):
- __visit_name__ = "CIDR"
-
-
-PGCidr = CIDR
-
-
-class MACADDR(sqltypes.TypeEngine[str]):
- __visit_name__ = "MACADDR"
-
-
-PGMacAddr = MACADDR
-
-
-class MONEY(sqltypes.TypeEngine[str]):
-
- r"""Provide the PostgreSQL MONEY type.
-
- Depending on driver, result rows using this type may return a
- string value which includes currency symbols.
-
- For this reason, it may be preferable to provide conversion to a
- numerically-based currency datatype using :class:`_types.TypeDecorator`::
-
- import re
- import decimal
- from sqlalchemy import TypeDecorator
-
- class NumericMoney(TypeDecorator):
- impl = MONEY
-
- def process_result_value(self, value: Any, dialect: Any) -> None:
- if value is not None:
- # adjust this for the currency and numeric
- m = re.match(r"\$([\d.]+)", value)
- if m:
- value = decimal.Decimal(m.group(1))
- return value
-
- Alternatively, the conversion may be applied as a CAST using
- the :meth:`_types.TypeDecorator.column_expression` method as follows::
-
- import decimal
- from sqlalchemy import cast
- from sqlalchemy import TypeDecorator
-
- class NumericMoney(TypeDecorator):
- impl = MONEY
-
- def column_expression(self, column: Any):
- return cast(column, Numeric())
-
- .. versionadded:: 1.2
-
- """
-
- __visit_name__ = "MONEY"
-
-
-class OID(sqltypes.TypeEngine[int]):
-
- """Provide the PostgreSQL OID type.
-
- .. versionadded:: 0.9.5
-
- """
-
- __visit_name__ = "OID"
-
-
-class REGCLASS(sqltypes.TypeEngine[str]):
-
- """Provide the PostgreSQL REGCLASS type.
-
- .. versionadded:: 1.2.7
-
- """
-
- __visit_name__ = "REGCLASS"
-
-
-class TIMESTAMP(sqltypes.TIMESTAMP):
- def __init__(self, timezone=False, precision=None):
- super(TIMESTAMP, self).__init__(timezone=timezone)
- self.precision = precision
-
-
-class TIME(sqltypes.TIME):
- def __init__(self, timezone=False, precision=None):
- super(TIME, self).__init__(timezone=timezone)
- self.precision = precision
-
-
-class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
-
- """PostgreSQL INTERVAL type."""
-
- __visit_name__ = "INTERVAL"
- native = True
-
- def __init__(self, precision=None, fields=None):
- """Construct an INTERVAL.
-
- :param precision: optional integer precision value
- :param fields: string fields specifier. allows storage of fields
- to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
- etc.
-
- .. versionadded:: 1.2
-
- """
- self.precision = precision
- self.fields = fields
-
- @classmethod
- def adapt_emulated_to_native(cls, interval, **kw):
- return INTERVAL(precision=interval.second_precision)
-
- @property
- def _type_affinity(self):
- return sqltypes.Interval
-
- def as_generic(self, allow_nulltype=False):
- return sqltypes.Interval(native=True, second_precision=self.precision)
-
- @property
- def python_type(self):
- return dt.timedelta
-
-
-PGInterval = INTERVAL
-
-
-class BIT(sqltypes.TypeEngine[int]):
- __visit_name__ = "BIT"
-
- def __init__(self, length=None, varying=False):
- if not varying:
- # BIT without VARYING defaults to length 1
- self.length = length or 1
- else:
- # but BIT VARYING can be unlimited-length, so no default
- self.length = length
- self.varying = varying
-
-
-PGBit = BIT
-
-
-class TSVECTOR(sqltypes.TypeEngine[Any]):
-
- """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
- text search type TSVECTOR.
-
- It can be used to do full text queries on natural language
- documents.
-
- .. versionadded:: 0.9.0
-
- .. seealso::
-
- :ref:`postgresql_match`
-
- """
-
- __visit_name__ = "TSVECTOR"
-
-
-class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
-
- """PostgreSQL ENUM type.
-
- This is a subclass of :class:`_types.Enum` which includes
- support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
-
- When the builtin type :class:`_types.Enum` is used and the
- :paramref:`.Enum.native_enum` flag is left at its default of
- True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
- type as the implementation, so the special create/drop rules
- will be used.
-
- The create/drop behavior of ENUM is necessarily intricate, due to the
- awkward relationship the ENUM type has in relationship to the
- parent table, in that it may be "owned" by just a single table, or
- may be shared among many tables.
-
- When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
- in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
- corresponding to when the :meth:`_schema.Table.create` and
- :meth:`_schema.Table.drop`
- methods are called::
-
- table = Table('sometable', metadata,
- Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
- )
-
- table.create(engine) # will emit CREATE ENUM and CREATE TABLE
- table.drop(engine) # will emit DROP TABLE and DROP ENUM
-
- To use a common enumerated type between multiple tables, the best
- practice is to declare the :class:`_types.Enum` or
- :class:`_postgresql.ENUM` independently, and associate it with the
- :class:`_schema.MetaData` object itself::
-
- my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
-
- t1 = Table('sometable_one', metadata,
- Column('some_enum', myenum)
- )
-
- t2 = Table('sometable_two', metadata,
- Column('some_enum', myenum)
- )
-
- When this pattern is used, care must still be taken at the level
- of individual table creates. Emitting CREATE TABLE without also
- specifying ``checkfirst=True`` will still cause issues::
-
- t1.create(engine) # will fail: no such type 'myenum'
-
- If we specify ``checkfirst=True``, the individual table-level create
- operation will check for the ``ENUM`` and create if not exists::
-
- # will check if enum exists, and emit CREATE TYPE if not
- t1.create(engine, checkfirst=True)
-
- When using a metadata-level ENUM type, the type will always be created
- and dropped if either the metadata-wide create/drop is called::
-
- metadata.create_all(engine) # will emit CREATE TYPE
- metadata.drop_all(engine) # will emit DROP TYPE
-
- The type can also be created and dropped directly::
-
- my_enum.create(engine)
- my_enum.drop(engine)
-
- .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
- now behaves more strictly with regards to CREATE/DROP. A metadata-level
- ENUM type will only be created and dropped at the metadata level,
- not the table level, with the exception of
- ``table.create(checkfirst=True)``.
- The ``table.drop()`` call will now emit a DROP TYPE for a table-level
- enumerated type.
-
- """
-
- native_enum = True
-
- def __init__(self, *enums, **kw):
- """Construct an :class:`_postgresql.ENUM`.
-
- Arguments are the same as that of
- :class:`_types.Enum`, but also including
- the following parameters.
-
- :param create_type: Defaults to True.
- Indicates that ``CREATE TYPE`` should be
- emitted, after optionally checking for the
- presence of the type, when the parent
- table is being created; and additionally
- that ``DROP TYPE`` is called when the table
- is dropped. When ``False``, no check
- will be performed and no ``CREATE TYPE``
- or ``DROP TYPE`` is emitted, unless
- :meth:`~.postgresql.ENUM.create`
- or :meth:`~.postgresql.ENUM.drop`
- are called directly.
- Setting to ``False`` is helpful
- when invoking a creation scheme to a SQL file
- without access to the actual database -
- the :meth:`~.postgresql.ENUM.create` and
- :meth:`~.postgresql.ENUM.drop` methods can
- be used to emit SQL to a target bind.
-
- """
- native_enum = kw.pop("native_enum", None)
- if native_enum is False:
- util.warn(
- "the native_enum flag does not apply to the "
- "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
- "always refers to ENUM. Use sqlalchemy.types.Enum for "
- "non-native enum."
- )
- self.create_type = kw.pop("create_type", True)
- super(ENUM, self).__init__(*enums, **kw)
-
- @classmethod
- def adapt_emulated_to_native(cls, impl, **kw):
- """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
- :class:`.Enum`.
-
- """
- kw.setdefault("validate_strings", impl.validate_strings)
- kw.setdefault("name", impl.name)
- kw.setdefault("schema", impl.schema)
- kw.setdefault("inherit_schema", impl.inherit_schema)
- kw.setdefault("metadata", impl.metadata)
- kw.setdefault("_create_events", False)
- kw.setdefault("values_callable", impl.values_callable)
- kw.setdefault("omit_aliases", impl._omit_aliases)
- return cls(**kw)
-
- def create(self, bind=None, checkfirst=True):
- """Emit ``CREATE TYPE`` for this
- :class:`_postgresql.ENUM`.
-
- If the underlying dialect does not support
- PostgreSQL CREATE TYPE, no action is taken.
-
- :param bind: a connectable :class:`_engine.Engine`,
- :class:`_engine.Connection`, or similar object to emit
- SQL.
- :param checkfirst: if ``True``, a query against
- the PG catalog will be first performed to see
- if the type does not exist already before
- creating.
-
- """
- if not bind.dialect.supports_native_enum:
- return
-
- bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
-
- def drop(self, bind=None, checkfirst=True):
- """Emit ``DROP TYPE`` for this
- :class:`_postgresql.ENUM`.
-
- If the underlying dialect does not support
- PostgreSQL DROP TYPE, no action is taken.
-
- :param bind: a connectable :class:`_engine.Engine`,
- :class:`_engine.Connection`, or similar object to emit
- SQL.
- :param checkfirst: if ``True``, a query against
- the PG catalog will be first performed to see
- if the type actually exists before dropping.
-
- """
- if not bind.dialect.supports_native_enum:
- return
-
- bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
-
- class EnumGenerator(InvokeDDLBase):
- def __init__(self, dialect, connection, checkfirst=False, **kwargs):
- super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
-
- def _can_create_enum(self, enum):
- if not self.checkfirst:
- return True
-
- effective_schema = self.connection.schema_for_object(enum)
-
- return not self.connection.dialect.has_type(
- self.connection, enum.name, schema=effective_schema
- )
-
- def visit_enum(self, enum):
- if not self._can_create_enum(enum):
- return
-
- self.connection.execute(CreateEnumType(enum))
-
- class EnumDropper(InvokeDDLBase):
- def __init__(self, dialect, connection, checkfirst=False, **kwargs):
- super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
- self.checkfirst = checkfirst
-
- def _can_drop_enum(self, enum):
- if not self.checkfirst:
- return True
-
- effective_schema = self.connection.schema_for_object(enum)
-
- return self.connection.dialect.has_type(
- self.connection, enum.name, schema=effective_schema
- )
-
- def visit_enum(self, enum):
- if not self._can_drop_enum(enum):
- return
-
- self.connection.execute(DropEnumType(enum))
-
- def get_dbapi_type(self, dbapi):
- """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
- a different type"""
-
- return None
-
- def _check_for_name_in_memos(self, checkfirst, kw):
- """Look in the 'ddl runner' for 'memos', then
- note our name in that collection.
-
- This to ensure a particular named enum is operated
- upon only once within any kind of create/drop
- sequence without relying upon "checkfirst".
-
- """
- if not self.create_type:
- return True
- if "_ddl_runner" in kw:
- ddl_runner = kw["_ddl_runner"]
- if "_pg_enums" in ddl_runner.memo:
- pg_enums = ddl_runner.memo["_pg_enums"]
- else:
- pg_enums = ddl_runner.memo["_pg_enums"] = set()
- present = (self.schema, self.name) in pg_enums
- pg_enums.add((self.schema, self.name))
- return present
- else:
- return False
-
- def _on_table_create(self, target, bind, checkfirst=False, **kw):
- if (
- checkfirst
- or (
- not self.metadata
- and not kw.get("_is_metadata_operation", False)
- )
- ) and not self._check_for_name_in_memos(checkfirst, kw):
- self.create(bind=bind, checkfirst=checkfirst)
-
- def _on_table_drop(self, target, bind, checkfirst=False, **kw):
- if (
- not self.metadata
- and not kw.get("_is_metadata_operation", False)
- and not self._check_for_name_in_memos(checkfirst, kw)
- ):
- self.drop(bind=bind, checkfirst=checkfirst)
-
- def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
- if not self._check_for_name_in_memos(checkfirst, kw):
- self.create(bind=bind, checkfirst=checkfirst)
-
- def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
- if not self._check_for_name_in_memos(checkfirst, kw):
- self.drop(bind=bind, checkfirst=checkfirst)
-
-
colspecs = {
sqltypes.ARRAY: _array.ARRAY,
sqltypes.Interval: INTERVAL,
@@ -2997,8 +2563,19 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer):
class PGInspector(reflection.Inspector):
+ dialect: PGDialect
+
def get_table_oid(self, table_name, schema=None):
- """Return the OID for the given table name."""
+ """Return the OID for the given table name.
+
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ """
with self._operation_context() as conn:
return self.dialect.get_table_oid(
@@ -3023,9 +2600,10 @@ class PGInspector(reflection.Inspector):
.. versionadded:: 1.0.0
"""
- schema = schema or self.default_schema_name
with self._operation_context() as conn:
- return self.dialect._load_enums(conn, schema)
+ return self.dialect._load_enums(
+ conn, schema, info_cache=self.info_cache
+ )
def get_foreign_table_names(self, schema=None):
"""Return a list of FOREIGN TABLE names.
@@ -3038,38 +2616,29 @@ class PGInspector(reflection.Inspector):
.. versionadded:: 1.0.0
"""
- schema = schema or self.default_schema_name
with self._operation_context() as conn:
- return self.dialect._get_foreign_table_names(conn, schema)
-
- def get_view_names(self, schema=None, include=("plain", "materialized")):
- """Return all view names in `schema`.
+ return self.dialect._get_foreign_table_names(
+ conn, schema, info_cache=self.info_cache
+ )
- :param schema: Optional, retrieve names from a non-default schema.
- For special quoting, use :class:`.quoted_name`.
+ def has_type(self, type_name, schema=None, **kw):
+ """Return if the database has the specified type in the provided
+ schema.
- :param include: specify which types of views to return. Passed
- as a string value (for a single type) or a tuple (for any number
- of types). Defaults to ``('plain', 'materialized')``.
+ :param type_name: the type to check.
+ :param schema: schema name. If None, the default schema
+ (typically 'public') is used. May also be set to '*' to
+ check in all schemas.
- .. versionadded:: 1.1
+ .. versionadded:: 2.0
"""
-
with self._operation_context() as conn:
- return self.dialect.get_view_names(
- conn, schema, info_cache=self.info_cache, include=include
+ return self.dialect.has_type(
+ conn, type_name, schema, info_cache=self.info_cache
)
-class CreateEnumType(schema._CreateDropBase):
- __visit_name__ = "create_enum_type"
-
-
-class DropEnumType(schema._CreateDropBase):
- __visit_name__ = "drop_enum_type"
-
-
class PGExecutionContext(default.DefaultExecutionContext):
def fire_sequence(self, seq, type_):
return self._execute_scalar(
@@ -3262,35 +2831,14 @@ class PGDialect(default.DefaultDialect):
def initialize(self, connection):
super(PGDialect, self).initialize(connection)
- if self.server_version_info <= (8, 2):
- self.delete_returning = (
- self.update_returning
- ) = self.insert_returning = False
-
- self.supports_native_enum = self.server_version_info >= (8, 3)
- if not self.supports_native_enum:
- self.colspecs = self.colspecs.copy()
- # pop base Enum type
- self.colspecs.pop(sqltypes.Enum, None)
- # psycopg2, others may have placed ENUM here as well
- self.colspecs.pop(ENUM, None)
-
# https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689
self.supports_smallserial = self.server_version_info >= (9, 2)
- if self.server_version_info < (8, 2):
- self._backslash_escapes = False
- else:
- # ensure this query is not emitted on server version < 8.2
- # as it will fail
- std_string = connection.exec_driver_sql(
- "show standard_conforming_strings"
- ).scalar()
- self._backslash_escapes = std_string == "off"
-
- self._supports_create_index_concurrently = (
- self.server_version_info >= (8, 2)
- )
+ std_string = connection.exec_driver_sql(
+ "show standard_conforming_strings"
+ ).scalar()
+ self._backslash_escapes = std_string == "off"
+
self._supports_drop_index_concurrently = self.server_version_info >= (
9,
2,
@@ -3370,122 +2918,100 @@ class PGDialect(default.DefaultDialect):
self.do_commit(connection.connection)
def do_recover_twophase(self, connection):
- resultset = connection.execute(
+ return connection.scalars(
sql.text("SELECT gid FROM pg_prepared_xacts")
- )
- return [row[0] for row in resultset]
+ ).all()
def _get_default_schema_name(self, connection):
return connection.exec_driver_sql("select current_schema()").scalar()
- def has_schema(self, connection, schema):
- query = (
- "select nspname from pg_namespace " "where lower(nspname)=:schema"
- )
- cursor = connection.execute(
- sql.text(query).bindparams(
- sql.bindparam(
- "schema",
- str(schema.lower()),
- type_=sqltypes.Unicode,
- )
- )
+ @reflection.cache
+ def has_schema(self, connection, schema, **kw):
+ query = select(pg_catalog.pg_namespace.c.nspname).where(
+ pg_catalog.pg_namespace.c.nspname == schema
)
+ return bool(connection.scalar(query))
- return bool(cursor.first())
-
- def has_table(self, connection, table_name, schema=None):
- self._ensure_has_table_connection(connection)
- # seems like case gets folded in pg_class...
+ def _pg_class_filter_scope_schema(
+ self, query, schema, scope, pg_class_table=None
+ ):
+ if pg_class_table is None:
+ pg_class_table = pg_catalog.pg_class
+ query = query.join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace,
+ )
+ if scope is ObjectScope.DEFAULT:
+ query = query.where(pg_class_table.c.relpersistence != "t")
+ elif scope is ObjectScope.TEMPORARY:
+ query = query.where(pg_class_table.c.relpersistence == "t")
if schema is None:
- cursor = connection.execute(
- sql.text(
- "select relname from pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where "
- "pg_catalog.pg_table_is_visible(c.oid) "
- "and relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(table_name),
- type_=sqltypes.Unicode,
- )
- )
+ query = query.where(
+ pg_catalog.pg_table_is_visible(pg_class_table.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
)
else:
- cursor = connection.execute(
- sql.text(
- "select relname from pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where n.nspname=:schema and "
- "relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(table_name),
- type_=sqltypes.Unicode,
- ),
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
- )
- return bool(cursor.first())
-
- def has_sequence(self, connection, sequence_name, schema=None):
- if schema is None:
- schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT relname FROM pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where relkind='S' and "
- "n.nspname=:schema and relname=:name"
- ).bindparams(
- sql.bindparam(
- "name",
- str(sequence_name),
- type_=sqltypes.Unicode,
- ),
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+ return query
+
+ def _pg_class_relkind_condition(self, relkinds, pg_class_table=None):
+ if pg_class_table is None:
+ pg_class_table = pg_catalog.pg_class
+ # uses the any form instead of in otherwise postgresql complaings
+ # that 'IN could not convert type character to "char"'
+ return pg_class_table.c.relkind == sql.any_(_array.array(relkinds))
+
+ @lru_cache()
+ def _has_table_query(self, schema):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ pg_catalog.pg_class.c.relname == bindparam("table_name"),
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ ),
+ )
+ return self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
)
- return bool(cursor.first())
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
+ self._ensure_has_table_connection(connection)
+ query = self._has_table_query(schema)
+ return bool(connection.scalar(query, {"table_name": table_name}))
- def has_type(self, connection, type_name, schema=None):
- if schema is not None:
- query = """
- SELECT EXISTS (
- SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n
- WHERE t.typnamespace = n.oid
- AND t.typname = :typname
- AND n.nspname = :nspname
- )
- """
- query = sql.text(query)
- else:
- query = """
- SELECT EXISTS (
- SELECT * FROM pg_catalog.pg_type t
- WHERE t.typname = :typname
- AND pg_type_is_visible(t.oid)
- )
- """
- query = sql.text(query)
- query = query.bindparams(
- sql.bindparam("typname", str(type_name), type_=sqltypes.Unicode)
+ @reflection.cache
+ def has_sequence(self, connection, sequence_name, schema=None, **kw):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ pg_catalog.pg_class.c.relkind == "S",
+ pg_catalog.pg_class.c.relname == sequence_name,
)
- if schema is not None:
- query = query.bindparams(
- sql.bindparam("nspname", str(schema), type_=sqltypes.Unicode)
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ return bool(connection.scalar(query))
+
+ @reflection.cache
+ def has_type(self, connection, type_name, schema=None, **kw):
+ query = (
+ select(pg_catalog.pg_type.c.typname)
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
)
- cursor = connection.execute(query)
- return bool(cursor.scalar())
+ .where(pg_catalog.pg_type.c.typname == type_name)
+ )
+ if schema is None:
+ query = query.where(
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+ )
+ elif schema != "*":
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+
+ return bool(connection.scalar(query))
def _get_server_version_info(self, connection):
v = connection.exec_driver_sql("select pg_catalog.version()").scalar()
@@ -3502,229 +3028,300 @@ class PGDialect(default.DefaultDialect):
@reflection.cache
def get_table_oid(self, connection, table_name, schema=None, **kw):
- """Fetch the oid for schema.table_name.
-
- Several reflection methods require the table oid. The idea for using
- this method is that it can be fetched one time and cached for
- subsequent calls.
-
- """
- table_oid = None
- if schema is not None:
- schema_where_clause = "n.nspname = :schema"
- else:
- schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)"
- query = (
- """
- SELECT c.oid
- FROM pg_catalog.pg_class c
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
- WHERE (%s)
- AND c.relname = :table_name AND c.relkind in
- ('r', 'v', 'm', 'f', 'p')
- """
- % schema_where_clause
+ """Fetch the oid for schema.table_name."""
+ query = select(pg_catalog.pg_class.c.oid).where(
+ pg_catalog.pg_class.c.relname == table_name,
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ ),
)
- # Since we're binding to unicode, table_name and schema_name must be
- # unicode.
- table_name = str(table_name)
- if schema is not None:
- schema = str(schema)
- s = sql.text(query).bindparams(table_name=sqltypes.Unicode)
- s = s.columns(oid=sqltypes.Integer)
- if schema:
- s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode))
- c = connection.execute(s, dict(table_name=table_name, schema=schema))
- table_oid = c.scalar()
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ table_oid = connection.scalar(query)
if table_oid is None:
- raise exc.NoSuchTableError(table_name)
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
return table_oid
@reflection.cache
def get_schema_names(self, connection, **kw):
- result = connection.execute(
- sql.text(
- "SELECT nspname FROM pg_namespace "
- "WHERE nspname NOT LIKE 'pg_%' "
- "ORDER BY nspname"
- ).columns(nspname=sqltypes.Unicode)
+ query = (
+ select(pg_catalog.pg_namespace.c.nspname)
+ .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%"))
+ .order_by(pg_catalog.pg_namespace.c.nspname)
+ )
+ return connection.scalars(query).all()
+
+ def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope):
+ query = select(pg_catalog.pg_class.c.relname).where(
+ self._pg_class_relkind_condition(relkinds)
)
- return [name for name, in result]
+ query = self._pg_class_filter_scope_schema(query, schema, scope=scope)
+ return connection.scalars(query).all()
@reflection.cache
def get_table_names(self, connection, schema=None, **kw):
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')"
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_TABLE_NO_FOREIGN,
+ scope=ObjectScope.DEFAULT,
+ )
+
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema=None,
+ relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN,
+ scope=ObjectScope.TEMPORARY,
)
- return [name for name, in result]
@reflection.cache
def _get_foreign_table_names(self, connection, schema=None, **kw):
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind = 'f'"
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ return self._get_relnames_for_relkinds(
+ connection, schema, relkinds=("f",), scope=ObjectScope.ANY
)
- return [name for name, in result]
@reflection.cache
- def get_view_names(
- self, connection, schema=None, include=("plain", "materialized"), **kw
- ):
+ def get_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_VIEW,
+ scope=ObjectScope.DEFAULT,
+ )
- include_kind = {"plain": "v", "materialized": "m"}
- try:
- kinds = [include_kind[i] for i in util.to_list(include)]
- except KeyError:
- raise ValueError(
- "include %r unknown, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'" % (include,)
- )
- if not kinds:
- raise ValueError(
- "empty include, needs to be a sequence containing "
- "one or both of 'plain' and 'materialized'"
- )
+ @reflection.cache
+ def get_materialized_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ pg_catalog.RELKINDS_MAT_VIEW,
+ scope=ObjectScope.DEFAULT,
+ )
- result = connection.execute(
- sql.text(
- "SELECT c.relname FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relkind IN (%s)"
- % (", ".join("'%s'" % elem for elem in kinds))
- ).columns(relname=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name
- ),
+ @reflection.cache
+ def get_temp_view_names(self, connection, schema=None, **kw):
+ return self._get_relnames_for_relkinds(
+ connection,
+ schema,
+ # NOTE: do not include temp materialzied views (that do not
+ # seem to be a thing at least up to version 14)
+ pg_catalog.RELKINDS_VIEW,
+ scope=ObjectScope.TEMPORARY,
)
- return [name for name, in result]
@reflection.cache
def get_sequence_names(self, connection, schema=None, **kw):
- if not schema:
- schema = self.default_schema_name
- cursor = connection.execute(
- sql.text(
- "SELECT relname FROM pg_class c join pg_namespace n on "
- "n.oid=c.relnamespace where relkind='S' and "
- "n.nspname=:schema"
- ).bindparams(
- sql.bindparam(
- "schema",
- str(schema),
- type_=sqltypes.Unicode,
- ),
- )
+ return self._get_relnames_for_relkinds(
+ connection, schema, relkinds=("S",), scope=ObjectScope.ANY
)
- return [row[0] for row in cursor]
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
- view_def = connection.scalar(
- sql.text(
- "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c "
- "JOIN pg_namespace n ON n.oid = c.relnamespace "
- "WHERE n.nspname = :schema AND c.relname = :view_name "
- "AND c.relkind IN ('v', 'm')"
- ).columns(view_def=sqltypes.Unicode),
- dict(
- schema=schema
- if schema is not None
- else self.default_schema_name,
- view_name=view_name,
- ),
+ query = (
+ select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid))
+ .select_from(pg_catalog.pg_class)
+ .where(
+ pg_catalog.pg_class.c.relname == view_name,
+ self._pg_class_relkind_condition(
+ pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW
+ ),
+ )
)
- return view_def
+ query = self._pg_class_filter_scope_schema(
+ query, schema, scope=ObjectScope.ANY
+ )
+ res = connection.scalar(query)
+ if res is None:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
+ else:
+ return res
+
+ def _value_or_raise(self, data, table, schema):
+ try:
+ return dict(data)[(schema, table)]
+ except KeyError:
+ raise exc.NoSuchTableError(
+ f"{schema}.{table}" if schema else table
+ ) from None
+
+ def _prepare_filter_names(self, filter_names):
+ if filter_names:
+ return True, {"filter_names": filter_names}
+ else:
+ return False, {}
+
+ def _kind_to_relkinds(self, kind: ObjectKind) -> tuple[str, ...]:
+ if kind is ObjectKind.ANY:
+ return pg_catalog.RELKINDS_ALL_TABLE_LIKE
+ relkinds = ()
+ if ObjectKind.TABLE in kind:
+ relkinds += pg_catalog.RELKINDS_TABLE
+ if ObjectKind.VIEW in kind:
+ relkinds += pg_catalog.RELKINDS_VIEW
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ relkinds += pg_catalog.RELKINDS_MAT_VIEW
+ return relkinds
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
-
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_columns(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
+ @lru_cache()
+ def _columns_query(self, schema, has_filter_names, scope, kind):
+ # NOTE: the query with the default and identity options scalar
+ # subquery is faster than trying to use outer joins for them
generated = (
- "a.attgenerated as generated"
+ pg_catalog.pg_attribute.c.attgenerated.label("generated")
if self.server_version_info >= (12,)
- else "NULL as generated"
+ else sql.null().label("generated")
)
if self.server_version_info >= (10,):
- # a.attidentity != '' is required or it will reflect also
- # serial columns as identity.
- identity = """\
- (SELECT json_build_object(
- 'always', a.attidentity = 'a',
- 'start', s.seqstart,
- 'increment', s.seqincrement,
- 'minvalue', s.seqmin,
- 'maxvalue', s.seqmax,
- 'cache', s.seqcache,
- 'cycle', s.seqcycle)
- FROM pg_catalog.pg_sequence s
- JOIN pg_catalog.pg_class c on s.seqrelid = c."oid"
- WHERE c.relkind = 'S'
- AND a.attidentity != ''
- AND s.seqrelid = pg_catalog.pg_get_serial_sequence(
- a.attrelid::regclass::text, a.attname
- )::regclass::oid
- ) as identity_options\
- """
+ # join lateral performs worse (~2x slower) than a scalar_subquery
+ identity = (
+ select(
+ sql.func.json_build_object(
+ "always",
+ pg_catalog.pg_attribute.c.attidentity == "a",
+ "start",
+ pg_catalog.pg_sequence.c.seqstart,
+ "increment",
+ pg_catalog.pg_sequence.c.seqincrement,
+ "minvalue",
+ pg_catalog.pg_sequence.c.seqmin,
+ "maxvalue",
+ pg_catalog.pg_sequence.c.seqmax,
+ "cache",
+ pg_catalog.pg_sequence.c.seqcache,
+ "cycle",
+ pg_catalog.pg_sequence.c.seqcycle,
+ )
+ )
+ .select_from(pg_catalog.pg_sequence)
+ .where(
+ # attidentity != '' is required or it will reflect also
+ # serial columns as identity.
+ pg_catalog.pg_attribute.c.attidentity != "",
+ pg_catalog.pg_sequence.c.seqrelid
+ == sql.cast(
+ sql.cast(
+ pg_catalog.pg_get_serial_sequence(
+ sql.cast(
+ sql.cast(
+ pg_catalog.pg_attribute.c.attrelid,
+ REGCLASS,
+ ),
+ TEXT,
+ ),
+ pg_catalog.pg_attribute.c.attname,
+ ),
+ REGCLASS,
+ ),
+ OID,
+ ),
+ )
+ .correlate(pg_catalog.pg_attribute)
+ .scalar_subquery()
+ .label("identity_options")
+ )
else:
- identity = "NULL as identity_options"
-
- SQL_COLS = """
- SELECT a.attname,
- pg_catalog.format_type(a.atttypid, a.atttypmod),
- (
- SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid)
- FROM pg_catalog.pg_attrdef d
- WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum
- AND a.atthasdef
- ) AS DEFAULT,
- a.attnotnull,
- a.attrelid as table_oid,
- pgd.description as comment,
- %s,
- %s
- FROM pg_catalog.pg_attribute a
- LEFT JOIN pg_catalog.pg_description pgd ON (
- pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum)
- WHERE a.attrelid = :table_oid
- AND a.attnum > 0 AND NOT a.attisdropped
- ORDER BY a.attnum
- """ % (
- generated,
- identity,
+ identity = sql.null().label("identity_options")
+
+ # join lateral performs the same as scalar_subquery here
+ default = (
+ select(
+ pg_catalog.pg_get_expr(
+ pg_catalog.pg_attrdef.c.adbin,
+ pg_catalog.pg_attrdef.c.adrelid,
+ )
+ )
+ .select_from(pg_catalog.pg_attrdef)
+ .where(
+ pg_catalog.pg_attrdef.c.adrelid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_attrdef.c.adnum
+ == pg_catalog.pg_attribute.c.attnum,
+ pg_catalog.pg_attribute.c.atthasdef,
+ )
+ .correlate(pg_catalog.pg_attribute)
+ .scalar_subquery()
+ .label("default")
)
- s = (
- sql.text(SQL_COLS)
- .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer))
- .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode)
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_attribute.c.attname.label("name"),
+ pg_catalog.format_type(
+ pg_catalog.pg_attribute.c.atttypid,
+ pg_catalog.pg_attribute.c.atttypmod,
+ ).label("format_type"),
+ default,
+ pg_catalog.pg_attribute.c.attnotnull.label("not_null"),
+ pg_catalog.pg_class.c.relname.label("table_name"),
+ pg_catalog.pg_description.c.description.label("comment"),
+ generated,
+ identity,
+ )
+ .select_from(pg_catalog.pg_class)
+ # NOTE: postgresql support table with no user column, meaning
+ # there is no row with pg_attribute.attnum > 0. use a left outer
+ # join to avoid filtering these tables.
+ .outerjoin(
+ pg_catalog.pg_attribute,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_attribute.c.attnum > 0,
+ ~pg_catalog.pg_attribute.c.attisdropped,
+ ),
+ )
+ .outerjoin(
+ pg_catalog.pg_description,
+ sql.and_(
+ pg_catalog.pg_description.c.objoid
+ == pg_catalog.pg_attribute.c.attrelid,
+ pg_catalog.pg_description.c.objsubid
+ == pg_catalog.pg_attribute.c.attnum,
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ .order_by(
+ pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum
+ )
)
- c = connection.execute(s, dict(table_oid=table_oid))
- rows = c.fetchall()
+ query = self._pg_class_filter_scope_schema(query, schema, scope=scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
+
+ def get_multi_columns(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._columns_query(schema, has_filter_names, scope, kind)
+ rows = connection.execute(query, params).mappings()
# dictionary with (name, ) if default search path or (schema, name)
# as keys
- domains = self._load_domains(connection)
+ domains = self._load_domains(
+ connection, info_cache=kw.get("info_cache")
+ )
# dictionary with (name, ) if default search path or (schema, name)
# as keys
@@ -3732,257 +3329,340 @@ class PGDialect(default.DefaultDialect):
((rec["name"],), rec)
if rec["visible"]
else ((rec["schema"], rec["name"]), rec)
- for rec in self._load_enums(connection, schema="*")
+ for rec in self._load_enums(
+ connection, schema="*", info_cache=kw.get("info_cache")
+ )
)
- # format columns
- columns = []
-
- for (
- name,
- format_type,
- default_,
- notnull,
- table_oid,
- comment,
- generated,
- identity,
- ) in rows:
- column_info = self._get_column_info(
- name,
- format_type,
- default_,
- notnull,
- domains,
- enums,
- schema,
- comment,
- generated,
- identity,
- )
- columns.append(column_info)
- return columns
+ columns = self._get_columns_info(rows, domains, enums, schema)
+
+ return columns.items()
+
+ def _get_columns_info(self, rows, domains, enums, schema):
+ array_type_pattern = re.compile(r"\[\]$")
+ attype_pattern = re.compile(r"\(.*\)")
+ charlen_pattern = re.compile(r"\(([\d,]+)\)")
+ args_pattern = re.compile(r"\((.*)\)")
+ args_split_pattern = re.compile(r"\s*,\s*")
- def _get_column_info(
- self,
- name,
- format_type,
- default,
- notnull,
- domains,
- enums,
- schema,
- comment,
- generated,
- identity,
- ):
def _handle_array_type(attype):
return (
# strip '[]' from integer[], etc.
- re.sub(r"\[\]$", "", attype),
+ array_type_pattern.sub("", attype),
attype.endswith("[]"),
)
- # strip (*) from character varying(5), timestamp(5)
- # with time zone, geometry(POLYGON), etc.
- attype = re.sub(r"\(.*\)", "", format_type)
+ columns = defaultdict(list)
+ for row_dict in rows:
+ # ensure that each table has an entry, even if it has no columns
+ if row_dict["name"] is None:
+ columns[
+ (schema, row_dict["table_name"])
+ ] = ReflectionDefaults.columns()
+ continue
+ table_cols = columns[(schema, row_dict["table_name"])]
- # strip '[]' from integer[], etc. and check if an array
- attype, is_array = _handle_array_type(attype)
+ format_type = row_dict["format_type"]
+ default = row_dict["default"]
+ name = row_dict["name"]
+ generated = row_dict["generated"]
+ identity = row_dict["identity_options"]
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+ # strip (*) from character varying(5), timestamp(5)
+ # with time zone, geometry(POLYGON), etc.
+ attype = attype_pattern.sub("", format_type)
- nullable = not notnull
+ # strip '[]' from integer[], etc. and check if an array
+ attype, is_array = _handle_array_type(attype)
- charlen = re.search(r"\(([\d,]+)\)", format_type)
- if charlen:
- charlen = charlen.group(1)
- args = re.search(r"\((.*)\)", format_type)
- if args and args.group(1):
- args = tuple(re.split(r"\s*,\s*", args.group(1)))
- else:
- args = ()
- kwargs = {}
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(util.quoted_token_parser(attype))
+
+ nullable = not row_dict["not_null"]
- if attype == "numeric":
+ charlen = charlen_pattern.search(format_type)
if charlen:
- prec, scale = charlen.split(",")
- args = (int(prec), int(scale))
+ charlen = charlen.group(1)
+ args = args_pattern.search(format_type)
+ if args and args.group(1):
+ args = tuple(args_split_pattern.split(args.group(1)))
else:
args = ()
- elif attype == "double precision":
- args = (53,)
- elif attype == "integer":
- args = ()
- elif attype in ("timestamp with time zone", "time with time zone"):
- kwargs["timezone"] = True
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype in (
- "timestamp without time zone",
- "time without time zone",
- "time",
- ):
- kwargs["timezone"] = False
- if charlen:
- kwargs["precision"] = int(charlen)
- args = ()
- elif attype == "bit varying":
- kwargs["varying"] = True
- if charlen:
+ kwargs = {}
+
+ if attype == "numeric":
+ if charlen:
+ prec, scale = charlen.split(",")
+ args = (int(prec), int(scale))
+ else:
+ args = ()
+ elif attype == "double precision":
+ args = (53,)
+ elif attype == "integer":
+ args = ()
+ elif attype in ("timestamp with time zone", "time with time zone"):
+ kwargs["timezone"] = True
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype in (
+ "timestamp without time zone",
+ "time without time zone",
+ "time",
+ ):
+ kwargs["timezone"] = False
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ args = ()
+ elif attype == "bit varying":
+ kwargs["varying"] = True
+ if charlen:
+ args = (int(charlen),)
+ else:
+ args = ()
+ elif attype.startswith("interval"):
+ field_match = re.match(r"interval (.+)", attype, re.I)
+ if charlen:
+ kwargs["precision"] = int(charlen)
+ if field_match:
+ kwargs["fields"] = field_match.group(1)
+ attype = "interval"
+ args = ()
+ elif charlen:
args = (int(charlen),)
+
+ while True:
+ # looping here to suit nested domains
+ if attype in self.ischema_names:
+ coltype = self.ischema_names[attype]
+ break
+ elif enum_or_domain_key in enums:
+ enum = enums[enum_or_domain_key]
+ coltype = ENUM
+ kwargs["name"] = enum["name"]
+ if not enum["visible"]:
+ kwargs["schema"] = enum["schema"]
+ args = tuple(enum["labels"])
+ break
+ elif enum_or_domain_key in domains:
+ domain = domains[enum_or_domain_key]
+ attype = domain["attype"]
+ attype, is_array = _handle_array_type(attype)
+ # strip quotes from case sensitive enum or domain names
+ enum_or_domain_key = tuple(
+ util.quoted_token_parser(attype)
+ )
+ # A table can't override a not null on the domain,
+ # but can override nullable
+ nullable = nullable and domain["nullable"]
+ if domain["default"] and not default:
+ # It can, however, override the default
+ # value, but can't set it to null.
+ default = domain["default"]
+ continue
+ else:
+ coltype = None
+ break
+
+ if coltype:
+ coltype = coltype(*args, **kwargs)
+ if is_array:
+ coltype = self.ischema_names["_array"](coltype)
else:
- args = ()
- elif attype.startswith("interval"):
- field_match = re.match(r"interval (.+)", attype, re.I)
- if charlen:
- kwargs["precision"] = int(charlen)
- if field_match:
- kwargs["fields"] = field_match.group(1)
- attype = "interval"
- args = ()
- elif charlen:
- args = (int(charlen),)
-
- while True:
- # looping here to suit nested domains
- if attype in self.ischema_names:
- coltype = self.ischema_names[attype]
- break
- elif enum_or_domain_key in enums:
- enum = enums[enum_or_domain_key]
- coltype = ENUM
- kwargs["name"] = enum["name"]
- if not enum["visible"]:
- kwargs["schema"] = enum["schema"]
- args = tuple(enum["labels"])
- break
- elif enum_or_domain_key in domains:
- domain = domains[enum_or_domain_key]
- attype = domain["attype"]
- attype, is_array = _handle_array_type(attype)
- # strip quotes from case sensitive enum or domain names
- enum_or_domain_key = tuple(util.quoted_token_parser(attype))
- # A table can't override a not null on the domain,
- # but can override nullable
- nullable = nullable and domain["nullable"]
- if domain["default"] and not default:
- # It can, however, override the default
- # value, but can't set it to null.
- default = domain["default"]
- continue
+ util.warn(
+ "Did not recognize type '%s' of column '%s'"
+ % (attype, name)
+ )
+ coltype = sqltypes.NULLTYPE
+
+ # If a zero byte or blank string depending on driver (is also
+ # absent for older PG versions), then not a generated column.
+ # Otherwise, s = stored. (Other values might be added in the
+ # future.)
+ if generated not in (None, "", b"\x00"):
+ computed = dict(
+ sqltext=default, persisted=generated in ("s", b"s")
+ )
+ default = None
else:
- coltype = None
- break
+ computed = None
- if coltype:
- coltype = coltype(*args, **kwargs)
- if is_array:
- coltype = self.ischema_names["_array"](coltype)
- else:
- util.warn(
- "Did not recognize type '%s' of column '%s'" % (attype, name)
+ # adjust the default value
+ autoincrement = False
+ if default is not None:
+ match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
+ if match is not None:
+ if issubclass(coltype._type_affinity, sqltypes.Integer):
+ autoincrement = True
+ # the default is related to a Sequence
+ if "." not in match.group(2) and schema is not None:
+ # unconditionally quote the schema name. this could
+ # later be enhanced to obey quoting rules /
+ # "quote schema"
+ default = (
+ match.group(1)
+ + ('"%s"' % schema)
+ + "."
+ + match.group(2)
+ + match.group(3)
+ )
+
+ column_info = {
+ "name": name,
+ "type": coltype,
+ "nullable": nullable,
+ "default": default,
+ "autoincrement": autoincrement or identity is not None,
+ "comment": row_dict["comment"],
+ }
+ if computed is not None:
+ column_info["computed"] = computed
+ if identity is not None:
+ column_info["identity"] = identity
+
+ table_cols.append(column_info)
+
+ return columns
+
+ @lru_cache()
+ def _table_oids_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ oid_q = select(
+ pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname
+ ).where(self._pg_class_relkind_condition(relkinds))
+ oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope)
+
+ if has_filter_names:
+ oid_q = oid_q.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
)
- coltype = sqltypes.NULLTYPE
-
- # If a zero byte or blank string depending on driver (is also absent
- # for older PG versions), then not a generated column. Otherwise, s =
- # stored. (Other values might be added in the future.)
- if generated not in (None, "", b"\x00"):
- computed = dict(
- sqltext=default, persisted=generated in ("s", b"s")
+ return oid_q
+
+ @reflection.flexi_cache(
+ ("schema", InternalTraversal.dp_string),
+ ("filter_names", InternalTraversal.dp_string_list),
+ ("kind", InternalTraversal.dp_plain_obj),
+ ("scope", InternalTraversal.dp_plain_obj),
+ )
+ def _get_table_oids(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ oid_q = self._table_oids_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(oid_q, params)
+ return result.all()
+
+ @util.memoized_property
+ def _constraint_query(self):
+ con_sq = (
+ select(
+ pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.conname,
+ sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label(
+ "attnum"
+ ),
+ sql.func.generate_subscripts(
+ pg_catalog.pg_constraint.c.conkey, 1
+ ).label("ord"),
)
- default = None
- else:
- computed = None
-
- # adjust the default value
- autoincrement = False
- if default is not None:
- match = re.search(r"""(nextval\(')([^']+)('.*$)""", default)
- if match is not None:
- if issubclass(coltype._type_affinity, sqltypes.Integer):
- autoincrement = True
- # the default is related to a Sequence
- sch = schema
- if "." not in match.group(2) and sch is not None:
- # unconditionally quote the schema name. this could
- # later be enhanced to obey quoting rules /
- # "quote schema"
- default = (
- match.group(1)
- + ('"%s"' % sch)
- + "."
- + match.group(2)
- + match.group(3)
- )
+ .where(
+ pg_catalog.pg_constraint.c.contype == bindparam("contype"),
+ pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")),
+ )
+ .subquery("con")
+ )
- column_info = dict(
- name=name,
- type=coltype,
- nullable=nullable,
- default=default,
- autoincrement=autoincrement or identity is not None,
- comment=comment,
+ attr_sq = (
+ select(
+ con_sq.c.conrelid,
+ con_sq.c.conname,
+ pg_catalog.pg_attribute.c.attname,
+ )
+ .select_from(pg_catalog.pg_attribute)
+ .join(
+ con_sq,
+ sql.and_(
+ pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum,
+ pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid,
+ ),
+ )
+ .order_by(con_sq.c.conname, con_sq.c.ord)
+ .subquery("attr")
)
- if computed is not None:
- column_info["computed"] = computed
- if identity is not None:
- column_info["identity"] = identity
- return column_info
- @reflection.cache
- def get_pk_constraint(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ return (
+ select(
+ attr_sq.c.conrelid,
+ sql.func.array_agg(attr_sq.c.attname).label("cols"),
+ attr_sq.c.conname,
+ )
+ .group_by(attr_sq.c.conrelid, attr_sq.c.conname)
+ .order_by(attr_sq.c.conrelid, attr_sq.c.conname)
)
- if self.server_version_info < (8, 4):
- PK_SQL = """
- SELECT a.attname
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_attribute a
- on t.oid=a.attrelid AND %s
- WHERE
- t.oid = :table_oid and ix.indisprimary = 't'
- ORDER BY a.attnum
- """ % self._pg_index_any(
- "a.attnum", "ix.indkey"
+ def _reflect_constraint(
+ self, connection, contype, schema, filter_names, scope, kind, **kw
+ ):
+ table_oids = self._get_table_oids(
+ connection, schema, filter_names, scope, kind, **kw
+ )
+ batches = list(table_oids)
+
+ while batches:
+ batch = batches[0:3000]
+ batches[0:3000] = []
+
+ result = connection.execute(
+ self._constraint_query,
+ {"oids": [r[0] for r in batch], "contype": contype},
)
- else:
- # unnest() and generate_subscripts() both introduced in
- # version 8.4
- PK_SQL = """
- SELECT a.attname
- FROM pg_attribute a JOIN (
- SELECT unnest(ix.indkey) attnum,
- generate_subscripts(ix.indkey, 1) ord
- FROM pg_index ix
- WHERE ix.indrelid = :table_oid AND ix.indisprimary
- ) k ON a.attnum=k.attnum
- WHERE a.attrelid = :table_oid
- ORDER BY k.ord
- """
- t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
- cols = [r[0] for r in c.fetchall()]
-
- PK_CONS_SQL = """
- SELECT conname
- FROM pg_catalog.pg_constraint r
- WHERE r.conrelid = :table_oid AND r.contype = 'p'
- ORDER BY 1
- """
- t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
- name = c.scalar()
+ result_by_oid = defaultdict(list)
+ for oid, cols, constraint_name in result:
+ result_by_oid[oid].append((cols, constraint_name))
+
+ for oid, tablename in batch:
+ for_oid = result_by_oid.get(oid, ())
+ if for_oid:
+ for cols, constraint in for_oid:
+ yield tablename, cols, constraint
+ else:
+ yield tablename, None, None
+
+ @reflection.cache
+ def get_pk_constraint(self, connection, table_name, schema=None, **kw):
+ data = self.get_multi_pk_constraint(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
+ )
+ return self._value_or_raise(data, table_name, schema)
+
+ def get_multi_pk_constraint(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ result = self._reflect_constraint(
+ connection, "p", schema, filter_names, scope, kind, **kw
+ )
- return {"constrained_columns": cols, "name": name}
+ # only a single pk can be present for each table. Return an entry
+ # even if a table has no primary key
+ default = ReflectionDefaults.pk_constraint
+ return (
+ (
+ (schema, table_name),
+ {
+ "constrained_columns": [] if cols is None else cols,
+ "name": pk_name,
+ }
+ if pk_name is not None
+ else default(),
+ )
+ for (table_name, cols, pk_name) in result
+ )
@reflection.cache
def get_foreign_keys(
@@ -3993,27 +3673,71 @@ class PGDialect(default.DefaultDialect):
postgresql_ignore_search_path=False,
**kw,
):
- preparer = self.identifier_preparer
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_foreign_keys(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ postgresql_ignore_search_path=postgresql_ignore_search_path,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- FK_SQL = """
- SELECT r.conname,
- pg_catalog.pg_get_constraintdef(r.oid, true) as condef,
- n.nspname as conschema
- FROM pg_catalog.pg_constraint r,
- pg_namespace n,
- pg_class c
-
- WHERE r.conrelid = :table AND
- r.contype = 'f' AND
- c.oid = confrelid AND
- n.oid = c.relnamespace
- ORDER BY 1
- """
- # https://www.postgresql.org/docs/9.0/static/sql-createtable.html
- FK_REGEX = re.compile(
+ @lru_cache()
+ def _foreing_key_query(self, schema, has_filter_names, scope, kind):
+ pg_class_ref = pg_catalog.pg_class.alias("cls_ref")
+ pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref")
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ sql.case(
+ (
+ pg_catalog.pg_constraint.c.oid.is_not(None),
+ pg_catalog.pg_get_constraintdef(
+ pg_catalog.pg_constraint.c.oid, True
+ ),
+ ),
+ else_=None,
+ ),
+ pg_namespace_ref.c.nspname,
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.contype == "f",
+ ),
+ )
+ .outerjoin(
+ pg_class_ref,
+ pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid,
+ )
+ .outerjoin(
+ pg_namespace_ref,
+ pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid,
+ )
+ .order_by(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
+
+ @util.memoized_property
+ def _fk_regex_pattern(self):
+ # https://www.postgresql.org/docs/14.0/static/sql-createtable.html
+ return re.compile(
r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)"
r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?"
r"[\s]?(ON UPDATE "
@@ -4024,12 +3748,33 @@ class PGDialect(default.DefaultDialect):
r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?"
)
- t = sql.text(FK_SQL).columns(
- conname=sqltypes.Unicode, condef=sqltypes.Unicode
- )
- c = connection.execute(t, dict(table=table_oid))
- fkeys = []
- for conname, condef, conschema in c.fetchall():
+ def get_multi_foreign_keys(
+ self,
+ connection,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ postgresql_ignore_search_path=False,
+ **kw,
+ ):
+ preparer = self.identifier_preparer
+
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._foreing_key_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(query, params)
+
+ FK_REGEX = self._fk_regex_pattern
+
+ fkeys = defaultdict(list)
+ default = ReflectionDefaults.foreign_keys
+ for table_name, conname, condef, conschema in result:
+ # ensure that each table has an entry, even if it has
+ # no foreign keys
+ if conname is None:
+ fkeys[(schema, table_name)] = default()
+ continue
+ table_fks = fkeys[(schema, table_name)]
m = re.search(FK_REGEX, condef).groups()
(
@@ -4096,317 +3841,406 @@ class PGDialect(default.DefaultDialect):
"referred_columns": referred_columns,
"options": options,
}
- fkeys.append(fkey_d)
- return fkeys
-
- def _pg_index_any(self, col, compare_to):
- if self.server_version_info < (8, 1):
- # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us
- # "In CVS tip you could replace this with "attnum = ANY (indkey)".
- # Unfortunately, most array support doesn't work on int2vector in
- # pre-8.1 releases, so I think you're kinda stuck with the above
- # for now.
- # regards, tom lane"
- return "(%s)" % " OR ".join(
- "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10)
- )
- else:
- return "%s = ANY(%s)" % (col, compare_to)
+ table_fks.append(fkey_d)
+ return fkeys.items()
@reflection.cache
- def get_indexes(self, connection, table_name, schema, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ def get_indexes(self, connection, table_name, schema=None, **kw):
+ data = self.get_multi_indexes(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- # cast indkey as varchar since it's an int2vector,
- # returned as a list by some drivers such as pypostgresql
-
- if self.server_version_info < (8, 5):
- IDX_SQL = """
- SELECT
- i.relname as relname,
- ix.indisunique, ix.indexprs, ix.indpred,
- a.attname, a.attnum, NULL, ix.indkey%s,
- %s, %s, am.amname,
- NULL as indnkeyatts
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_class i on i.oid = ix.indexrelid
- left outer join
- pg_attribute a
- on t.oid = a.attrelid and %s
- left outer join
- pg_am am
- on i.relam = am.oid
- WHERE
- t.relkind IN ('r', 'v', 'f', 'm')
- and t.oid = :table_oid
- and ix.indisprimary = 'f'
- ORDER BY
- t.relname,
- i.relname
- """ % (
- # version 8.3 here was based on observing the
- # cast does not work in PG 8.2.4, does work in 8.3.0.
- # nothing in PG changelogs regarding this.
- "::varchar" if self.server_version_info >= (8, 3) else "",
- "ix.indoption::varchar"
- if self.server_version_info >= (8, 3)
- else "NULL",
- "i.reloptions"
- if self.server_version_info >= (8, 2)
- else "NULL",
- self._pg_index_any("a.attnum", "ix.indkey"),
+ @util.memoized_property
+ def _index_query(self):
+ pg_class_index = pg_catalog.pg_class.alias("cls_idx")
+ # NOTE: repeating oids clause improve query performance
+
+ # subquery to get the columns
+ idx_sq = (
+ select(
+ pg_catalog.pg_index.c.indexrelid,
+ pg_catalog.pg_index.c.indrelid,
+ sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"),
+ sql.func.generate_subscripts(
+ pg_catalog.pg_index.c.indkey, 1
+ ).label("ord"),
)
- else:
- IDX_SQL = """
- SELECT
- i.relname as relname,
- ix.indisunique, ix.indexprs,
- a.attname, a.attnum, c.conrelid, ix.indkey::varchar,
- ix.indoption::varchar, i.reloptions, am.amname,
- pg_get_expr(ix.indpred, ix.indrelid),
- %s as indnkeyatts
- FROM
- pg_class t
- join pg_index ix on t.oid = ix.indrelid
- join pg_class i on i.oid = ix.indexrelid
- left outer join
- pg_attribute a
- on t.oid = a.attrelid and a.attnum = ANY(ix.indkey)
- left outer join
- pg_constraint c
- on (ix.indrelid = c.conrelid and
- ix.indexrelid = c.conindid and
- c.contype in ('p', 'u', 'x'))
- left outer join
- pg_am am
- on i.relam = am.oid
- WHERE
- t.relkind IN ('r', 'v', 'f', 'm', 'p')
- and t.oid = :table_oid
- and ix.indisprimary = 'f'
- ORDER BY
- t.relname,
- i.relname
- """ % (
- "ix.indnkeyatts"
- if self.server_version_info >= (11, 0)
- else "NULL",
+ .where(
+ ~pg_catalog.pg_index.c.indisprimary,
+ pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")),
)
+ .subquery("idx")
+ )
- t = sql.text(IDX_SQL).columns(
- relname=sqltypes.Unicode, attname=sqltypes.Unicode
+ attr_sq = (
+ select(
+ idx_sq.c.indexrelid,
+ idx_sq.c.indrelid,
+ pg_catalog.pg_attribute.c.attname,
+ )
+ .select_from(pg_catalog.pg_attribute)
+ .join(
+ idx_sq,
+ sql.and_(
+ pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum,
+ pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid,
+ ),
+ )
+ .where(idx_sq.c.indrelid.in_(bindparam("oids")))
+ .order_by(idx_sq.c.indexrelid, idx_sq.c.ord)
+ .subquery("idx_attr")
)
- c = connection.execute(t, dict(table_oid=table_oid))
- indexes = defaultdict(lambda: defaultdict(dict))
+ cols_sq = (
+ select(
+ attr_sq.c.indexrelid,
+ attr_sq.c.indrelid,
+ sql.func.array_agg(attr_sq.c.attname).label("cols"),
+ )
+ .group_by(attr_sq.c.indexrelid, attr_sq.c.indrelid)
+ .subquery("idx_cols")
+ )
- sv_idx_name = None
- for row in c.fetchall():
- (
- idx_name,
- unique,
- expr,
- col,
- col_num,
- conrelid,
- idx_key,
- idx_option,
- options,
- amname,
- filter_definition,
- indnkeyatts,
- ) = row
+ if self.server_version_info >= (11, 0):
+ indnkeyatts = pg_catalog.pg_index.c.indnkeyatts
+ else:
+ indnkeyatts = sql.null().label("indnkeyatts")
- if expr:
- if idx_name != sv_idx_name:
- util.warn(
- "Skipped unsupported reflection of "
- "expression-based index %s" % idx_name
- )
- sv_idx_name = idx_name
- continue
+ query = (
+ select(
+ pg_catalog.pg_index.c.indrelid,
+ pg_class_index.c.relname.label("relname_index"),
+ pg_catalog.pg_index.c.indisunique,
+ pg_catalog.pg_index.c.indexprs,
+ pg_catalog.pg_constraint.c.conrelid.is_not(None).label(
+ "has_constraint"
+ ),
+ pg_catalog.pg_index.c.indoption,
+ pg_class_index.c.reloptions,
+ pg_catalog.pg_am.c.amname,
+ pg_catalog.pg_get_expr(
+ pg_catalog.pg_index.c.indpred,
+ pg_catalog.pg_index.c.indrelid,
+ ).label("filter_definition"),
+ indnkeyatts,
+ cols_sq.c.cols.label("index_cols"),
+ )
+ .select_from(pg_catalog.pg_index)
+ .where(
+ pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")),
+ ~pg_catalog.pg_index.c.indisprimary,
+ )
+ .join(
+ pg_class_index,
+ pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid,
+ )
+ .join(
+ pg_catalog.pg_am,
+ pg_class_index.c.relam == pg_catalog.pg_am.c.oid,
+ )
+ .outerjoin(
+ cols_sq,
+ pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid,
+ )
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_index.c.indrelid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_index.c.indexrelid
+ == pg_catalog.pg_constraint.c.conindid,
+ pg_catalog.pg_constraint.c.contype
+ == sql.any_(_array.array(("p", "u", "x"))),
+ ),
+ )
+ .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname)
+ )
+ return query
- has_idx = idx_name in indexes
- index = indexes[idx_name]
- if col is not None:
- index["cols"][col_num] = col
- if not has_idx:
- idx_keys = idx_key.split()
- # "The number of key columns in the index, not counting any
- # included columns, which are merely stored and do not
- # participate in the index semantics"
- if indnkeyatts and idx_keys[indnkeyatts:]:
- # this is a "covering index" which has INCLUDE columns
- # as well as regular index columns
- inc_keys = idx_keys[indnkeyatts:]
- idx_keys = idx_keys[:indnkeyatts]
- else:
- inc_keys = []
+ def get_multi_indexes(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
- index["key"] = [int(k.strip()) for k in idx_keys]
- index["inc"] = [int(k.strip()) for k in inc_keys]
+ table_oids = self._get_table_oids(
+ connection, schema, filter_names, scope, kind, **kw
+ )
- # (new in pg 8.3)
- # "pg_index.indoption" is list of ints, one per column/expr.
- # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST
- sorting = {}
- for col_idx, col_flags in enumerate(
- (idx_option or "").split()
- ):
- col_flags = int(col_flags.strip())
- col_sorting = ()
- # try to set flags only if they differ from PG defaults...
- if col_flags & 0x01:
- col_sorting += ("desc",)
- if not (col_flags & 0x02):
- col_sorting += ("nulls_last",)
+ indexes = defaultdict(list)
+ default = ReflectionDefaults.indexes
+
+ batches = list(table_oids)
+
+ while batches:
+ batch = batches[0:3000]
+ batches[0:3000] = []
+
+ result = connection.execute(
+ self._index_query, {"oids": [r[0] for r in batch]}
+ ).mappings()
+
+ result_by_oid = defaultdict(list)
+ for row_dict in result:
+ result_by_oid[row_dict["indrelid"]].append(row_dict)
+
+ for oid, table_name in batch:
+ if oid not in result_by_oid:
+ # ensure that each table has an entry, even if reflection
+ # is skipped because not supported
+ indexes[(schema, table_name)] = default()
+ continue
+
+ for row in result_by_oid[oid]:
+ index_name = row["relname_index"]
+
+ table_indexes = indexes[(schema, table_name)]
+
+ if row["indexprs"]:
+ tn = (
+ table_name
+ if schema is None
+ else f"{schema}.{table_name}"
+ )
+ util.warn(
+ "Skipped unsupported reflection of "
+ f"expression-based index {index_name} of "
+ f"table {tn}"
+ )
+ continue
+
+ all_cols = row["index_cols"]
+ indnkeyatts = row["indnkeyatts"]
+ # "The number of key columns in the index, not counting any
+ # included columns, which are merely stored and do not
+ # participate in the index semantics"
+ if indnkeyatts and all_cols[indnkeyatts:]:
+ # this is a "covering index" which has INCLUDE columns
+ # as well as regular index columns
+ inc_cols = all_cols[indnkeyatts:]
+ idx_cols = all_cols[:indnkeyatts]
else:
- if col_flags & 0x02:
- col_sorting += ("nulls_first",)
- if col_sorting:
- sorting[col_idx] = col_sorting
- if sorting:
- index["sorting"] = sorting
-
- index["unique"] = unique
- if conrelid is not None:
- index["duplicates_constraint"] = idx_name
- if options:
- index["options"] = dict(
- [option.split("=") for option in options]
- )
-
- # it *might* be nice to include that this is 'btree' in the
- # reflection info. But we don't want an Index object
- # to have a ``postgresql_using`` in it that is just the
- # default, so for the moment leaving this out.
- if amname and amname != "btree":
- index["amname"] = amname
-
- if filter_definition:
- index["postgresql_where"] = filter_definition
+ idx_cols = all_cols
+ inc_cols = []
+
+ index = {
+ "name": index_name,
+ "unique": row["indisunique"],
+ "column_names": idx_cols,
+ }
+
+ sorting = {}
+ for col_index, col_flags in enumerate(row["indoption"]):
+ col_sorting = ()
+ # try to set flags only if they differ from PG
+ # defaults...
+ if col_flags & 0x01:
+ col_sorting += ("desc",)
+ if not (col_flags & 0x02):
+ col_sorting += ("nulls_last",)
+ else:
+ if col_flags & 0x02:
+ col_sorting += ("nulls_first",)
+ if col_sorting:
+ sorting[idx_cols[col_index]] = col_sorting
+ if sorting:
+ index["column_sorting"] = sorting
+ if row["has_constraint"]:
+ index["duplicates_constraint"] = index_name
+
+ dialect_options = {}
+ if row["reloptions"]:
+ dialect_options["postgresql_with"] = dict(
+ [option.split("=") for option in row["reloptions"]]
+ )
+ # it *might* be nice to include that this is 'btree' in the
+ # reflection info. But we don't want an Index object
+ # to have a ``postgresql_using`` in it that is just the
+ # default, so for the moment leaving this out.
+ amname = row["amname"]
+ if amname != "btree":
+ dialect_options["postgresql_using"] = row["amname"]
+ if row["filter_definition"]:
+ dialect_options["postgresql_where"] = row[
+ "filter_definition"
+ ]
+ if self.server_version_info >= (11, 0):
+ # NOTE: this is legacy, this is part of
+ # dialect_options now as of #7382
+ index["include_columns"] = inc_cols
+ dialect_options["postgresql_include"] = inc_cols
+ if dialect_options:
+ index["dialect_options"] = dialect_options
- result = []
- for name, idx in indexes.items():
- entry = {
- "name": name,
- "unique": idx["unique"],
- "column_names": [idx["cols"][i] for i in idx["key"]],
- }
- if self.server_version_info >= (11, 0):
- # NOTE: this is legacy, this is part of dialect_options now
- # as of #7382
- entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]]
- if "duplicates_constraint" in idx:
- entry["duplicates_constraint"] = idx["duplicates_constraint"]
- if "sorting" in idx:
- entry["column_sorting"] = dict(
- (idx["cols"][idx["key"][i]], value)
- for i, value in idx["sorting"].items()
- )
- if "include_columns" in entry:
- entry.setdefault("dialect_options", {})[
- "postgresql_include"
- ] = entry["include_columns"]
- if "options" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_with"
- ] = idx["options"]
- if "amname" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_using"
- ] = idx["amname"]
- if "postgresql_where" in idx:
- entry.setdefault("dialect_options", {})[
- "postgresql_where"
- ] = idx["postgresql_where"]
- result.append(entry)
- return result
+ table_indexes.append(index)
+ return indexes.items()
@reflection.cache
def get_unique_constraints(
self, connection, table_name, schema=None, **kw
):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_unique_constraints(
+ connection,
+ schema=schema,
+ filter_names=[table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- UNIQUE_SQL = """
- SELECT
- cons.conname as name,
- cons.conkey as key,
- a.attnum as col_num,
- a.attname as col_name
- FROM
- pg_catalog.pg_constraint cons
- join pg_attribute a
- on cons.conrelid = a.attrelid AND
- a.attnum = ANY(cons.conkey)
- WHERE
- cons.conrelid = :table_oid AND
- cons.contype = 'u'
- """
-
- t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode)
- c = connection.execute(t, dict(table_oid=table_oid))
+ def get_multi_unique_constraints(
+ self,
+ connection,
+ schema,
+ filter_names,
+ scope,
+ kind,
+ **kw,
+ ):
+ result = self._reflect_constraint(
+ connection, "u", schema, filter_names, scope, kind, **kw
+ )
- uniques = defaultdict(lambda: defaultdict(dict))
- for row in c.fetchall():
- uc = uniques[row.name]
- uc["key"] = row.key
- uc["cols"][row.col_num] = row.col_name
+ # each table can have multiple unique constraints
+ uniques = defaultdict(list)
+ default = ReflectionDefaults.unique_constraints
+ for (table_name, cols, con_name) in result:
+ # ensure a list is created for each table. leave it empty if
+ # the table has no unique cosntraint
+ if con_name is None:
+ uniques[(schema, table_name)] = default()
+ continue
- return [
- {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]}
- for name, uc in uniques.items()
- ]
+ uniques[(schema, table_name)].append(
+ {
+ "column_names": cols,
+ "name": con_name,
+ }
+ )
+ return uniques.items()
@reflection.cache
def get_table_comment(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_table_comment(
+ connection,
+ schema,
+ [table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- COMMENT_SQL = """
- SELECT
- pgd.description as table_comment
- FROM
- pg_catalog.pg_description pgd
- WHERE
- pgd.objsubid = 0 AND
- pgd.objoid = :table_oid
- """
+ @lru_cache()
+ def _comment_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_description.c.description,
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_description,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_description.c.objoid,
+ pg_catalog.pg_description.c.objsubid == 0,
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
- c = connection.execute(
- sql.text(COMMENT_SQL), dict(table_oid=table_oid)
+ def get_multi_table_comment(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._comment_query(schema, has_filter_names, scope, kind)
+ result = connection.execute(query, params)
+
+ default = ReflectionDefaults.table_comment
+ return (
+ (
+ (schema, table),
+ {"text": comment} if comment is not None else default(),
+ )
+ for table, comment in result
)
- return {"text": c.scalar()}
@reflection.cache
def get_check_constraints(self, connection, table_name, schema=None, **kw):
- table_oid = self.get_table_oid(
- connection, table_name, schema, info_cache=kw.get("info_cache")
+ data = self.get_multi_check_constraints(
+ connection,
+ schema,
+ [table_name],
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ **kw,
)
+ return self._value_or_raise(data, table_name, schema)
- CHECK_SQL = """
- SELECT
- cons.conname as name,
- pg_get_constraintdef(cons.oid) as src
- FROM
- pg_catalog.pg_constraint cons
- WHERE
- cons.conrelid = :table_oid AND
- cons.contype = 'c'
- """
-
- c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid))
+ @lru_cache()
+ def _check_constraint_query(self, schema, has_filter_names, scope, kind):
+ relkinds = self._kind_to_relkinds(kind)
+ query = (
+ select(
+ pg_catalog.pg_class.c.relname,
+ pg_catalog.pg_constraint.c.conname,
+ sql.case(
+ (
+ pg_catalog.pg_constraint.c.oid.is_not(None),
+ pg_catalog.pg_get_constraintdef(
+ pg_catalog.pg_constraint.c.oid
+ ),
+ ),
+ else_=None,
+ ),
+ )
+ .select_from(pg_catalog.pg_class)
+ .outerjoin(
+ pg_catalog.pg_constraint,
+ sql.and_(
+ pg_catalog.pg_class.c.oid
+ == pg_catalog.pg_constraint.c.conrelid,
+ pg_catalog.pg_constraint.c.contype == "c",
+ ),
+ )
+ .where(self._pg_class_relkind_condition(relkinds))
+ )
+ query = self._pg_class_filter_scope_schema(query, schema, scope)
+ if has_filter_names:
+ query = query.where(
+ pg_catalog.pg_class.c.relname.in_(bindparam("filter_names"))
+ )
+ return query
- ret = []
- for name, src in c:
+ def get_multi_check_constraints(
+ self, connection, schema, filter_names, scope, kind, **kw
+ ):
+ has_filter_names, params = self._prepare_filter_names(filter_names)
+ query = self._check_constraint_query(
+ schema, has_filter_names, scope, kind
+ )
+ result = connection.execute(query, params)
+
+ check_constraints = defaultdict(list)
+ default = ReflectionDefaults.check_constraints
+ for table_name, check_name, src in result:
+ # only two cases for check_name and src: both null or both defined
+ if check_name is None and src is None:
+ check_constraints[(schema, table_name)] = default()
+ continue
# samples:
# "CHECK (((a > 1) AND (a < 5)))"
# "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))"
@@ -4424,84 +4258,118 @@ class PGDialect(default.DefaultDialect):
sqltext = re.compile(
r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL
).sub(r"\1", m.group(1))
- entry = {"name": name, "sqltext": sqltext}
+ entry = {"name": check_name, "sqltext": sqltext}
if m and m.group(2):
entry["dialect_options"] = {"not_valid": True}
- ret.append(entry)
- return ret
-
- def _load_enums(self, connection, schema=None):
- schema = schema or self.default_schema_name
- if not self.supports_native_enum:
- return {}
-
- # Load data types for enums:
- SQL_ENUMS = """
- SELECT t.typname as "name",
- -- no enum defaults in 8.4 at least
- -- t.typdefault as "default",
- pg_catalog.pg_type_is_visible(t.oid) as "visible",
- n.nspname as "schema",
- e.enumlabel as "label"
- FROM pg_catalog.pg_type t
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
- LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid
- WHERE t.typtype = 'e'
- """
+ check_constraints[(schema, table_name)].append(entry)
+ return check_constraints.items()
- if schema != "*":
- SQL_ENUMS += "AND n.nspname = :schema "
+ @lru_cache()
+ def _enum_query(self, schema):
+ lbl_sq = (
+ select(
+ pg_catalog.pg_enum.c.enumtypid, pg_catalog.pg_enum.c.enumlabel
+ )
+ .order_by(
+ pg_catalog.pg_enum.c.enumtypid,
+ pg_catalog.pg_enum.c.enumsortorder,
+ )
+ .subquery("lbl")
+ )
- # e.oid gives us label order within an enum
- SQL_ENUMS += 'ORDER BY "schema", "name", e.oid'
+ lbl_agg_sq = (
+ select(
+ lbl_sq.c.enumtypid,
+ sql.func.array_agg(lbl_sq.c.enumlabel).label("labels"),
+ )
+ .group_by(lbl_sq.c.enumtypid)
+ .subquery("lbl_agg")
+ )
- s = sql.text(SQL_ENUMS).columns(
- attname=sqltypes.Unicode, label=sqltypes.Unicode
+ query = (
+ select(
+ pg_catalog.pg_type.c.typname.label("name"),
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+ "visible"
+ ),
+ pg_catalog.pg_namespace.c.nspname.label("schema"),
+ lbl_agg_sq.c.labels.label("labels"),
+ )
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
+ )
+ .outerjoin(
+ lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid
+ )
+ .where(pg_catalog.pg_type.c.typtype == "e")
+ .order_by(
+ pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname
+ )
)
- if schema != "*":
- s = s.bindparams(schema=schema)
+ if schema is None:
+ query = query.where(
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid),
+ # ignore pg_catalog schema
+ pg_catalog.pg_namespace.c.nspname != "pg_catalog",
+ )
+ elif schema != "*":
+ query = query.where(pg_catalog.pg_namespace.c.nspname == schema)
+ return query
+
+ @reflection.cache
+ def _load_enums(self, connection, schema=None, **kw):
+ if not self.supports_native_enum:
+ return []
- c = connection.execute(s)
+ result = connection.execute(self._enum_query(schema))
enums = []
- enum_by_name = {}
- for enum in c.fetchall():
- key = (enum.schema, enum.name)
- if key in enum_by_name:
- enum_by_name[key]["labels"].append(enum.label)
- else:
- enum_by_name[key] = enum_rec = {
- "name": enum.name,
- "schema": enum.schema,
- "visible": enum.visible,
- "labels": [],
+ for name, visible, schema, labels in result:
+ enums.append(
+ {
+ "name": name,
+ "schema": schema,
+ "visible": visible,
+ "labels": [] if labels is None else labels,
}
- if enum.label is not None:
- enum_rec["labels"].append(enum.label)
- enums.append(enum_rec)
+ )
return enums
- def _load_domains(self, connection):
- # Load data types for domains:
- SQL_DOMAINS = """
- SELECT t.typname as "name",
- pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype",
- not t.typnotnull as "nullable",
- t.typdefault as "default",
- pg_catalog.pg_type_is_visible(t.oid) as "visible",
- n.nspname as "schema"
- FROM pg_catalog.pg_type t
- LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace
- WHERE t.typtype = 'd'
- """
+ @util.memoized_property
+ def _domain_query(self):
+ return (
+ select(
+ pg_catalog.pg_type.c.typname.label("name"),
+ pg_catalog.format_type(
+ pg_catalog.pg_type.c.typbasetype,
+ pg_catalog.pg_type.c.typtypmod,
+ ).label("attype"),
+ (~pg_catalog.pg_type.c.typnotnull).label("nullable"),
+ pg_catalog.pg_type.c.typdefault.label("default"),
+ pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label(
+ "visible"
+ ),
+ pg_catalog.pg_namespace.c.nspname.label("schema"),
+ )
+ .join(
+ pg_catalog.pg_namespace,
+ pg_catalog.pg_namespace.c.oid
+ == pg_catalog.pg_type.c.typnamespace,
+ )
+ .where(pg_catalog.pg_type.c.typtype == "d")
+ )
- s = sql.text(SQL_DOMAINS)
- c = connection.execution_options(future_result=True).execute(s)
+ @reflection.cache
+ def _load_domains(self, connection, **kw):
+ # Load data types for domains:
+ result = connection.execute(self._domain_query)
domains = {}
- for domain in c.mappings():
+ for domain in result.mappings():
domain = domain
# strip (30) from character varying(30)
attype = re.search(r"([^\(]+)", domain["attype"]).group(1)
diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py
index 6cb97ece4..ce9a3bb6c 100644
--- a/lib/sqlalchemy/dialects/postgresql/pg8000.py
+++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py
@@ -107,6 +107,8 @@ from .base import PGIdentifierPreparer
from .json import JSON
from .json import JSONB
from .json import JSONPathType
+from .pg_catalog import _SpaceVector
+from .pg_catalog import OIDVECTOR
from ... import exc
from ... import util
from ...engine import processors
@@ -245,6 +247,10 @@ class _PGARRAY(PGARRAY):
render_bind_cast = True
+class _PGOIDVECTOR(_SpaceVector, OIDVECTOR):
+ pass
+
+
_server_side_id = util.counter()
@@ -376,6 +382,7 @@ class PGDialect_pg8000(PGDialect):
sqltypes.BigInteger: _PGBigInteger,
sqltypes.Enum: _PGEnum,
sqltypes.ARRAY: _PGARRAY,
+ OIDVECTOR: _PGOIDVECTOR,
},
)
diff --git a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py
new file mode 100644
index 000000000..a77e7ccf6
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py
@@ -0,0 +1,292 @@
+# postgresql/pg_catalog.py
+# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+from .array import ARRAY
+from .types import OID
+from .types import REGCLASS
+from ... import Column
+from ... import func
+from ... import MetaData
+from ... import Table
+from ...types import BigInteger
+from ...types import Boolean
+from ...types import CHAR
+from ...types import Float
+from ...types import Integer
+from ...types import SmallInteger
+from ...types import String
+from ...types import Text
+from ...types import TypeDecorator
+
+
+# types
+class NAME(TypeDecorator):
+ impl = String(64, collation="C")
+ cache_ok = True
+
+
+class PG_NODE_TREE(TypeDecorator):
+ impl = Text(collation="C")
+ cache_ok = True
+
+
+class INT2VECTOR(TypeDecorator):
+ impl = ARRAY(SmallInteger)
+ cache_ok = True
+
+
+class OIDVECTOR(TypeDecorator):
+ impl = ARRAY(OID)
+ cache_ok = True
+
+
+class _SpaceVector:
+ def result_processor(self, dialect, coltype):
+ def process(value):
+ if value is None:
+ return value
+ return [int(p) for p in value.split(" ")]
+
+ return process
+
+
+REGPROC = REGCLASS # seems an alias
+
+# functions
+_pg_cat = func.pg_catalog
+quote_ident = _pg_cat.quote_ident
+pg_table_is_visible = _pg_cat.pg_table_is_visible
+pg_type_is_visible = _pg_cat.pg_type_is_visible
+pg_get_viewdef = _pg_cat.pg_get_viewdef
+pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence
+format_type = _pg_cat.format_type
+pg_get_expr = _pg_cat.pg_get_expr
+pg_get_constraintdef = _pg_cat.pg_get_constraintdef
+
+# constants
+RELKINDS_TABLE_NO_FOREIGN = ("r", "p")
+RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",)
+RELKINDS_VIEW = ("v",)
+RELKINDS_MAT_VIEW = ("m",)
+RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW
+
+# tables
+pg_catalog_meta = MetaData()
+
+pg_namespace = Table(
+ "pg_namespace",
+ pg_catalog_meta,
+ Column("oid", OID),
+ Column("nspname", NAME),
+ Column("nspowner", OID),
+ schema="pg_catalog",
+)
+
+pg_class = Table(
+ "pg_class",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("relname", NAME),
+ Column("relnamespace", OID),
+ Column("reltype", OID),
+ Column("reloftype", OID),
+ Column("relowner", OID),
+ Column("relam", OID),
+ Column("relfilenode", OID),
+ Column("reltablespace", OID),
+ Column("relpages", Integer),
+ Column("reltuples", Float),
+ Column("relallvisible", Integer, info={"server_version": (9, 2)}),
+ Column("reltoastrelid", OID),
+ Column("relhasindex", Boolean),
+ Column("relisshared", Boolean),
+ Column("relpersistence", CHAR, info={"server_version": (9, 1)}),
+ Column("relkind", CHAR),
+ Column("relnatts", SmallInteger),
+ Column("relchecks", SmallInteger),
+ Column("relhasrules", Boolean),
+ Column("relhastriggers", Boolean),
+ Column("relhassubclass", Boolean),
+ Column("relrowsecurity", Boolean),
+ Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}),
+ Column("relispopulated", Boolean, info={"server_version": (9, 3)}),
+ Column("relreplident", CHAR, info={"server_version": (9, 4)}),
+ Column("relispartition", Boolean, info={"server_version": (10,)}),
+ Column("relrewrite", OID, info={"server_version": (11,)}),
+ Column("reloptions", ARRAY(Text)),
+ schema="pg_catalog",
+)
+
+pg_type = Table(
+ "pg_type",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("typname", NAME),
+ Column("typnamespace", OID),
+ Column("typowner", OID),
+ Column("typlen", SmallInteger),
+ Column("typbyval", Boolean),
+ Column("typtype", CHAR),
+ Column("typcategory", CHAR),
+ Column("typispreferred", Boolean),
+ Column("typisdefined", Boolean),
+ Column("typdelim", CHAR),
+ Column("typrelid", OID),
+ Column("typelem", OID),
+ Column("typarray", OID),
+ Column("typinput", REGPROC),
+ Column("typoutput", REGPROC),
+ Column("typreceive", REGPROC),
+ Column("typsend", REGPROC),
+ Column("typmodin", REGPROC),
+ Column("typmodout", REGPROC),
+ Column("typanalyze", REGPROC),
+ Column("typalign", CHAR),
+ Column("typstorage", CHAR),
+ Column("typnotnull", Boolean),
+ Column("typbasetype", OID),
+ Column("typtypmod", Integer),
+ Column("typndims", Integer),
+ Column("typcollation", OID, info={"server_version": (9, 1)}),
+ Column("typdefault", Text),
+ schema="pg_catalog",
+)
+
+pg_index = Table(
+ "pg_index",
+ pg_catalog_meta,
+ Column("indexrelid", OID),
+ Column("indrelid", OID),
+ Column("indnatts", SmallInteger),
+ Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}),
+ Column("indisunique", Boolean),
+ Column("indisprimary", Boolean),
+ Column("indisexclusion", Boolean, info={"server_version": (9, 1)}),
+ Column("indimmediate", Boolean),
+ Column("indisclustered", Boolean),
+ Column("indisvalid", Boolean),
+ Column("indcheckxmin", Boolean),
+ Column("indisready", Boolean),
+ Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3
+ Column("indisreplident", Boolean),
+ Column("indkey", INT2VECTOR),
+ Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1
+ Column("indclass", OIDVECTOR),
+ Column("indoption", INT2VECTOR),
+ Column("indexprs", PG_NODE_TREE),
+ Column("indpred", PG_NODE_TREE),
+ schema="pg_catalog",
+)
+
+pg_attribute = Table(
+ "pg_attribute",
+ pg_catalog_meta,
+ Column("attrelid", OID),
+ Column("attname", NAME),
+ Column("atttypid", OID),
+ Column("attstattarget", Integer),
+ Column("attlen", SmallInteger),
+ Column("attnum", SmallInteger),
+ Column("attndims", Integer),
+ Column("attcacheoff", Integer),
+ Column("atttypmod", Integer),
+ Column("attbyval", Boolean),
+ Column("attstorage", CHAR),
+ Column("attalign", CHAR),
+ Column("attnotnull", Boolean),
+ Column("atthasdef", Boolean),
+ Column("atthasmissing", Boolean, info={"server_version": (11,)}),
+ Column("attidentity", CHAR, info={"server_version": (10,)}),
+ Column("attgenerated", CHAR, info={"server_version": (12,)}),
+ Column("attisdropped", Boolean),
+ Column("attislocal", Boolean),
+ Column("attinhcount", Integer),
+ Column("attcollation", OID, info={"server_version": (9, 1)}),
+ schema="pg_catalog",
+)
+
+pg_constraint = Table(
+ "pg_constraint",
+ pg_catalog_meta,
+ Column("oid", OID), # 9.3
+ Column("conname", NAME),
+ Column("connamespace", OID),
+ Column("contype", CHAR),
+ Column("condeferrable", Boolean),
+ Column("condeferred", Boolean),
+ Column("convalidated", Boolean, info={"server_version": (9, 1)}),
+ Column("conrelid", OID),
+ Column("contypid", OID),
+ Column("conindid", OID),
+ Column("conparentid", OID, info={"server_version": (11,)}),
+ Column("confrelid", OID),
+ Column("confupdtype", CHAR),
+ Column("confdeltype", CHAR),
+ Column("confmatchtype", CHAR),
+ Column("conislocal", Boolean),
+ Column("coninhcount", Integer),
+ Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
+ Column("conkey", ARRAY(SmallInteger)),
+ Column("confkey", ARRAY(SmallInteger)),
+ schema="pg_catalog",
+)
+
+pg_sequence = Table(
+ "pg_sequence",
+ pg_catalog_meta,
+ Column("seqrelid", OID),
+ Column("seqtypid", OID),
+ Column("seqstart", BigInteger),
+ Column("seqincrement", BigInteger),
+ Column("seqmax", BigInteger),
+ Column("seqmin", BigInteger),
+ Column("seqcache", BigInteger),
+ Column("seqcycle", Boolean),
+ schema="pg_catalog",
+ info={"server_version": (10,)},
+)
+
+pg_attrdef = Table(
+ "pg_attrdef",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("adrelid", OID),
+ Column("adnum", SmallInteger),
+ Column("adbin", PG_NODE_TREE),
+ schema="pg_catalog",
+)
+
+pg_description = Table(
+ "pg_description",
+ pg_catalog_meta,
+ Column("objoid", OID),
+ Column("classoid", OID),
+ Column("objsubid", Integer),
+ Column("description", Text(collation="C")),
+ schema="pg_catalog",
+)
+
+pg_enum = Table(
+ "pg_enum",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("enumtypid", OID),
+ Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
+ Column("enumlabel", NAME),
+ schema="pg_catalog",
+)
+
+pg_am = Table(
+ "pg_am",
+ pg_catalog_meta,
+ Column("oid", OID, info={"server_version": (9, 3)}),
+ Column("amname", NAME),
+ Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
+ Column("amtype", CHAR, info={"server_version": (9, 6)}),
+ schema="pg_catalog",
+)
diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py
new file mode 100644
index 000000000..55735953b
--- /dev/null
+++ b/lib/sqlalchemy/dialects/postgresql/types.py
@@ -0,0 +1,485 @@
+# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors
+# <see AUTHORS file>
+#
+# This module is part of SQLAlchemy and is released under
+# the MIT License: https://www.opensource.org/licenses/mit-license.php
+# mypy: ignore-errors
+
+import datetime as dt
+from typing import Any
+
+from ... import schema
+from ... import util
+from ...sql import sqltypes
+from ...sql.ddl import InvokeDDLBase
+
+
+_DECIMAL_TYPES = (1231, 1700)
+_FLOAT_TYPES = (700, 701, 1021, 1022)
+_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
+
+
+class PGUuid(sqltypes.UUID):
+ render_bind_cast = True
+ render_literal_cast = True
+
+
+class BYTEA(sqltypes.LargeBinary[bytes]):
+ __visit_name__ = "BYTEA"
+
+
+class INET(sqltypes.TypeEngine[str]):
+ __visit_name__ = "INET"
+
+
+PGInet = INET
+
+
+class CIDR(sqltypes.TypeEngine[str]):
+ __visit_name__ = "CIDR"
+
+
+PGCidr = CIDR
+
+
+class MACADDR(sqltypes.TypeEngine[str]):
+ __visit_name__ = "MACADDR"
+
+
+PGMacAddr = MACADDR
+
+
+class MONEY(sqltypes.TypeEngine[str]):
+
+ r"""Provide the PostgreSQL MONEY type.
+
+ Depending on driver, result rows using this type may return a
+ string value which includes currency symbols.
+
+ For this reason, it may be preferable to provide conversion to a
+ numerically-based currency datatype using :class:`_types.TypeDecorator`::
+
+ import re
+ import decimal
+ from sqlalchemy import TypeDecorator
+
+ class NumericMoney(TypeDecorator):
+ impl = MONEY
+
+ def process_result_value(self, value: Any, dialect: Any) -> None:
+ if value is not None:
+ # adjust this for the currency and numeric
+ m = re.match(r"\$([\d.]+)", value)
+ if m:
+ value = decimal.Decimal(m.group(1))
+ return value
+
+ Alternatively, the conversion may be applied as a CAST using
+ the :meth:`_types.TypeDecorator.column_expression` method as follows::
+
+ import decimal
+ from sqlalchemy import cast
+ from sqlalchemy import TypeDecorator
+
+ class NumericMoney(TypeDecorator):
+ impl = MONEY
+
+ def column_expression(self, column: Any):
+ return cast(column, Numeric())
+
+ .. versionadded:: 1.2
+
+ """
+
+ __visit_name__ = "MONEY"
+
+
+class OID(sqltypes.TypeEngine[int]):
+
+ """Provide the PostgreSQL OID type.
+
+ .. versionadded:: 0.9.5
+
+ """
+
+ __visit_name__ = "OID"
+
+
+class REGCLASS(sqltypes.TypeEngine[str]):
+
+ """Provide the PostgreSQL REGCLASS type.
+
+ .. versionadded:: 1.2.7
+
+ """
+
+ __visit_name__ = "REGCLASS"
+
+
+class TIMESTAMP(sqltypes.TIMESTAMP):
+ def __init__(self, timezone=False, precision=None):
+ super(TIMESTAMP, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class TIME(sqltypes.TIME):
+ def __init__(self, timezone=False, precision=None):
+ super(TIME, self).__init__(timezone=timezone)
+ self.precision = precision
+
+
+class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
+
+ """PostgreSQL INTERVAL type."""
+
+ __visit_name__ = "INTERVAL"
+ native = True
+
+ def __init__(self, precision=None, fields=None):
+ """Construct an INTERVAL.
+
+ :param precision: optional integer precision value
+ :param fields: string fields specifier. allows storage of fields
+ to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
+ etc.
+
+ .. versionadded:: 1.2
+
+ """
+ self.precision = precision
+ self.fields = fields
+
+ @classmethod
+ def adapt_emulated_to_native(cls, interval, **kw):
+ return INTERVAL(precision=interval.second_precision)
+
+ @property
+ def _type_affinity(self):
+ return sqltypes.Interval
+
+ def as_generic(self, allow_nulltype=False):
+ return sqltypes.Interval(native=True, second_precision=self.precision)
+
+ @property
+ def python_type(self):
+ return dt.timedelta
+
+
+PGInterval = INTERVAL
+
+
+class BIT(sqltypes.TypeEngine[int]):
+ __visit_name__ = "BIT"
+
+ def __init__(self, length=None, varying=False):
+ if not varying:
+ # BIT without VARYING defaults to length 1
+ self.length = length or 1
+ else:
+ # but BIT VARYING can be unlimited-length, so no default
+ self.length = length
+ self.varying = varying
+
+
+PGBit = BIT
+
+
+class TSVECTOR(sqltypes.TypeEngine[Any]):
+
+ """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
+ text search type TSVECTOR.
+
+ It can be used to do full text queries on natural language
+ documents.
+
+ .. versionadded:: 0.9.0
+
+ .. seealso::
+
+ :ref:`postgresql_match`
+
+ """
+
+ __visit_name__ = "TSVECTOR"
+
+
+class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum):
+
+ """PostgreSQL ENUM type.
+
+ This is a subclass of :class:`_types.Enum` which includes
+ support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
+
+ When the builtin type :class:`_types.Enum` is used and the
+ :paramref:`.Enum.native_enum` flag is left at its default of
+ True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
+ type as the implementation, so the special create/drop rules
+ will be used.
+
+ The create/drop behavior of ENUM is necessarily intricate, due to the
+ awkward relationship the ENUM type has in relationship to the
+ parent table, in that it may be "owned" by just a single table, or
+ may be shared among many tables.
+
+ When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
+ in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
+ corresponding to when the :meth:`_schema.Table.create` and
+ :meth:`_schema.Table.drop`
+ methods are called::
+
+ table = Table('sometable', metadata,
+ Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
+ )
+
+ table.create(engine) # will emit CREATE ENUM and CREATE TABLE
+ table.drop(engine) # will emit DROP TABLE and DROP ENUM
+
+ To use a common enumerated type between multiple tables, the best
+ practice is to declare the :class:`_types.Enum` or
+ :class:`_postgresql.ENUM` independently, and associate it with the
+ :class:`_schema.MetaData` object itself::
+
+ my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
+
+ t1 = Table('sometable_one', metadata,
+ Column('some_enum', myenum)
+ )
+
+ t2 = Table('sometable_two', metadata,
+ Column('some_enum', myenum)
+ )
+
+ When this pattern is used, care must still be taken at the level
+ of individual table creates. Emitting CREATE TABLE without also
+ specifying ``checkfirst=True`` will still cause issues::
+
+ t1.create(engine) # will fail: no such type 'myenum'
+
+ If we specify ``checkfirst=True``, the individual table-level create
+ operation will check for the ``ENUM`` and create if not exists::
+
+ # will check if enum exists, and emit CREATE TYPE if not
+ t1.create(engine, checkfirst=True)
+
+ When using a metadata-level ENUM type, the type will always be created
+ and dropped if either the metadata-wide create/drop is called::
+
+ metadata.create_all(engine) # will emit CREATE TYPE
+ metadata.drop_all(engine) # will emit DROP TYPE
+
+ The type can also be created and dropped directly::
+
+ my_enum.create(engine)
+ my_enum.drop(engine)
+
+ .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type
+ now behaves more strictly with regards to CREATE/DROP. A metadata-level
+ ENUM type will only be created and dropped at the metadata level,
+ not the table level, with the exception of
+ ``table.create(checkfirst=True)``.
+ The ``table.drop()`` call will now emit a DROP TYPE for a table-level
+ enumerated type.
+
+ """
+
+ native_enum = True
+
+ def __init__(self, *enums, **kw):
+ """Construct an :class:`_postgresql.ENUM`.
+
+ Arguments are the same as that of
+ :class:`_types.Enum`, but also including
+ the following parameters.
+
+ :param create_type: Defaults to True.
+ Indicates that ``CREATE TYPE`` should be
+ emitted, after optionally checking for the
+ presence of the type, when the parent
+ table is being created; and additionally
+ that ``DROP TYPE`` is called when the table
+ is dropped. When ``False``, no check
+ will be performed and no ``CREATE TYPE``
+ or ``DROP TYPE`` is emitted, unless
+ :meth:`~.postgresql.ENUM.create`
+ or :meth:`~.postgresql.ENUM.drop`
+ are called directly.
+ Setting to ``False`` is helpful
+ when invoking a creation scheme to a SQL file
+ without access to the actual database -
+ the :meth:`~.postgresql.ENUM.create` and
+ :meth:`~.postgresql.ENUM.drop` methods can
+ be used to emit SQL to a target bind.
+
+ """
+ native_enum = kw.pop("native_enum", None)
+ if native_enum is False:
+ util.warn(
+ "the native_enum flag does not apply to the "
+ "sqlalchemy.dialects.postgresql.ENUM datatype; this type "
+ "always refers to ENUM. Use sqlalchemy.types.Enum for "
+ "non-native enum."
+ )
+ self.create_type = kw.pop("create_type", True)
+ super(ENUM, self).__init__(*enums, **kw)
+
+ @classmethod
+ def adapt_emulated_to_native(cls, impl, **kw):
+ """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
+ :class:`.Enum`.
+
+ """
+ kw.setdefault("validate_strings", impl.validate_strings)
+ kw.setdefault("name", impl.name)
+ kw.setdefault("schema", impl.schema)
+ kw.setdefault("inherit_schema", impl.inherit_schema)
+ kw.setdefault("metadata", impl.metadata)
+ kw.setdefault("_create_events", False)
+ kw.setdefault("values_callable", impl.values_callable)
+ kw.setdefault("omit_aliases", impl._omit_aliases)
+ return cls(**kw)
+
+ def create(self, bind=None, checkfirst=True):
+ """Emit ``CREATE TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL CREATE TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+ :param checkfirst: if ``True``, a query against
+ the PG catalog will be first performed to see
+ if the type does not exist already before
+ creating.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst)
+
+ def drop(self, bind=None, checkfirst=True):
+ """Emit ``DROP TYPE`` for this
+ :class:`_postgresql.ENUM`.
+
+ If the underlying dialect does not support
+ PostgreSQL DROP TYPE, no action is taken.
+
+ :param bind: a connectable :class:`_engine.Engine`,
+ :class:`_engine.Connection`, or similar object to emit
+ SQL.
+ :param checkfirst: if ``True``, a query against
+ the PG catalog will be first performed to see
+ if the type actually exists before dropping.
+
+ """
+ if not bind.dialect.supports_native_enum:
+ return
+
+ bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst)
+
+ class EnumGenerator(InvokeDDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumGenerator, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_create_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return not self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_create_enum(enum):
+ return
+
+ self.connection.execute(CreateEnumType(enum))
+
+ class EnumDropper(InvokeDDLBase):
+ def __init__(self, dialect, connection, checkfirst=False, **kwargs):
+ super(ENUM.EnumDropper, self).__init__(connection, **kwargs)
+ self.checkfirst = checkfirst
+
+ def _can_drop_enum(self, enum):
+ if not self.checkfirst:
+ return True
+
+ effective_schema = self.connection.schema_for_object(enum)
+
+ return self.connection.dialect.has_type(
+ self.connection, enum.name, schema=effective_schema
+ )
+
+ def visit_enum(self, enum):
+ if not self._can_drop_enum(enum):
+ return
+
+ self.connection.execute(DropEnumType(enum))
+
+ def get_dbapi_type(self, dbapi):
+ """dont return dbapi.STRING for ENUM in PostgreSQL, since that's
+ a different type"""
+
+ return None
+
+ def _check_for_name_in_memos(self, checkfirst, kw):
+ """Look in the 'ddl runner' for 'memos', then
+ note our name in that collection.
+
+ This to ensure a particular named enum is operated
+ upon only once within any kind of create/drop
+ sequence without relying upon "checkfirst".
+
+ """
+ if not self.create_type:
+ return True
+ if "_ddl_runner" in kw:
+ ddl_runner = kw["_ddl_runner"]
+ if "_pg_enums" in ddl_runner.memo:
+ pg_enums = ddl_runner.memo["_pg_enums"]
+ else:
+ pg_enums = ddl_runner.memo["_pg_enums"] = set()
+ present = (self.schema, self.name) in pg_enums
+ pg_enums.add((self.schema, self.name))
+ return present
+ else:
+ return False
+
+ def _on_table_create(self, target, bind, checkfirst=False, **kw):
+ if (
+ checkfirst
+ or (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ )
+ ) and not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_table_drop(self, target, bind, checkfirst=False, **kw):
+ if (
+ not self.metadata
+ and not kw.get("_is_metadata_operation", False)
+ and not self._check_for_name_in_memos(checkfirst, kw)
+ ):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.create(bind=bind, checkfirst=checkfirst)
+
+ def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
+ if not self._check_for_name_in_memos(checkfirst, kw):
+ self.drop(bind=bind, checkfirst=checkfirst)
+
+
+class CreateEnumType(schema._CreateDropBase):
+ __visit_name__ = "create_enum_type"
+
+
+class DropEnumType(schema._CreateDropBase):
+ __visit_name__ = "drop_enum_type"
diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py
index fdcd1340b..22f003e38 100644
--- a/lib/sqlalchemy/dialects/sqlite/base.py
+++ b/lib/sqlalchemy/dialects/sqlite/base.py
@@ -867,6 +867,7 @@ from ... import util
from ...engine import default
from ...engine import processors
from ...engine import reflection
+from ...engine.reflection import ReflectionDefaults
from ...sql import coercions
from ...sql import ColumnElement
from ...sql import compiler
@@ -2053,28 +2054,27 @@ class SQLiteDialect(default.DefaultDialect):
return [db[1] for db in dl if db[1] != "temp"]
- @reflection.cache
- def get_table_names(self, connection, schema=None, **kw):
+ def _format_schema(self, schema, table_name):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
+ name = f"{qschema}.{table_name}"
else:
- master = "sqlite_master"
- s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % (
- master,
- )
- rs = connection.exec_driver_sql(s)
- return [row[0] for row in rs]
+ name = table_name
+ return name
@reflection.cache
- def get_temp_table_names(self, connection, **kw):
- s = (
- "SELECT name FROM sqlite_temp_master "
- "WHERE type='table' ORDER BY name "
- )
- rs = connection.exec_driver_sql(s)
+ def get_table_names(self, connection, schema=None, **kw):
+ main = self._format_schema(schema, "sqlite_master")
+ s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
- return [row[0] for row in rs]
+ @reflection.cache
+ def get_temp_table_names(self, connection, **kw):
+ main = "sqlite_temp_master"
+ s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
@reflection.cache
def get_temp_view_names(self, connection, **kw):
@@ -2082,11 +2082,11 @@ class SQLiteDialect(default.DefaultDialect):
"SELECT name FROM sqlite_temp_master "
"WHERE type='view' ORDER BY name "
)
- rs = connection.exec_driver_sql(s)
-
- return [row[0] for row in rs]
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
- def has_table(self, connection, table_name, schema=None):
+ @reflection.cache
+ def has_table(self, connection, table_name, schema=None, **kw):
self._ensure_has_table_connection(connection)
info = self._get_table_pragma(
@@ -2099,23 +2099,16 @@ class SQLiteDialect(default.DefaultDialect):
@reflection.cache
def get_view_names(self, connection, schema=None, **kw):
- if schema is not None:
- qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
- else:
- master = "sqlite_master"
- s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % (
- master,
- )
- rs = connection.exec_driver_sql(s)
-
- return [row[0] for row in rs]
+ main = self._format_schema(schema, "sqlite_master")
+ s = f"SELECT name FROM {main} WHERE type='view' ORDER BY name"
+ names = connection.exec_driver_sql(s).scalars().all()
+ return names
@reflection.cache
def get_view_definition(self, connection, view_name, schema=None, **kw):
if schema is not None:
qschema = self.identifier_preparer.quote_identifier(schema)
- master = "%s.sqlite_master" % qschema
+ master = f"{qschema}.sqlite_master"
s = ("SELECT sql FROM %s WHERE name = ? AND type='view'") % (
master,
)
@@ -2140,6 +2133,10 @@ class SQLiteDialect(default.DefaultDialect):
result = rs.fetchall()
if result:
return result[0].sql
+ else:
+ raise exc.NoSuchTableError(
+ f"{schema}.{view_name}" if schema else view_name
+ )
@reflection.cache
def get_columns(self, connection, table_name, schema=None, **kw):
@@ -2186,7 +2183,14 @@ class SQLiteDialect(default.DefaultDialect):
tablesql,
)
)
- return columns
+ if columns:
+ return columns
+ elif not self.has_table(connection, table_name, schema):
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
+ else:
+ return ReflectionDefaults.columns()
def _get_column_info(
self,
@@ -2216,7 +2220,6 @@ class SQLiteDialect(default.DefaultDialect):
"type": coltype,
"nullable": nullable,
"default": default,
- "autoincrement": "auto",
"primary_key": primary_key,
}
if generated:
@@ -2295,13 +2298,16 @@ class SQLiteDialect(default.DefaultDialect):
constraint_name = result.group(1) if result else None
cols = self.get_columns(connection, table_name, schema, **kw)
+ # consider only pk columns. This also avoids sorting the cached
+ # value returned by get_columns
+ cols = [col for col in cols if col.get("primary_key", 0) > 0]
cols.sort(key=lambda col: col.get("primary_key"))
- pkeys = []
- for col in cols:
- if col["primary_key"]:
- pkeys.append(col["name"])
+ pkeys = [col["name"] for col in cols]
- return {"constrained_columns": pkeys, "name": constraint_name}
+ if pkeys:
+ return {"constrained_columns": pkeys, "name": constraint_name}
+ else:
+ return ReflectionDefaults.pk_constraint()
@reflection.cache
def get_foreign_keys(self, connection, table_name, schema=None, **kw):
@@ -2321,12 +2327,14 @@ class SQLiteDialect(default.DefaultDialect):
# original DDL. The referred columns of the foreign key
# constraint are therefore the primary key of the referred
# table.
- referred_pk = self.get_pk_constraint(
- connection, rtbl, schema=schema, **kw
- )
- # note that if table doesn't exist, we still get back a record,
- # just it has no columns in it
- referred_columns = referred_pk["constrained_columns"]
+ try:
+ referred_pk = self.get_pk_constraint(
+ connection, rtbl, schema=schema, **kw
+ )
+ referred_columns = referred_pk["constrained_columns"]
+ except exc.NoSuchTableError:
+ # ignore not existing parents
+ referred_columns = []
else:
# note we use this list only if this is the first column
# in the constraint. for subsequent columns we ignore the
@@ -2378,11 +2386,11 @@ class SQLiteDialect(default.DefaultDialect):
)
table_data = self._get_table_sql(connection, table_name, schema=schema)
- if table_data is None:
- # system tables, etc.
- return []
def parse_fks():
+ if table_data is None:
+ # system tables, etc.
+ return
FK_PATTERN = (
r"(?:CONSTRAINT (\w+) +)?"
r"FOREIGN KEY *\( *(.+?) *\) +"
@@ -2453,7 +2461,10 @@ class SQLiteDialect(default.DefaultDialect):
# use them as is as it's extremely difficult to parse inline
# constraints
fkeys.extend(keys_by_signature.values())
- return fkeys
+ if fkeys:
+ return fkeys
+ else:
+ return ReflectionDefaults.foreign_keys()
def _find_cols_in_sig(self, sig):
for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I):
@@ -2480,12 +2491,11 @@ class SQLiteDialect(default.DefaultDialect):
table_data = self._get_table_sql(
connection, table_name, schema=schema, **kw
)
- if not table_data:
- return []
-
unique_constraints = []
def parse_uqs():
+ if table_data is None:
+ return
UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? +)?UNIQUE *\((.+?)\)'
INLINE_UNIQUE_PATTERN = (
r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) '
@@ -2513,15 +2523,16 @@ class SQLiteDialect(default.DefaultDialect):
unique_constraints.append(parsed_constraint)
# NOTE: auto_index_by_sig might not be empty here,
# the PRIMARY KEY may have an entry.
- return unique_constraints
+ if unique_constraints:
+ return unique_constraints
+ else:
+ return ReflectionDefaults.unique_constraints()
@reflection.cache
def get_check_constraints(self, connection, table_name, schema=None, **kw):
table_data = self._get_table_sql(
connection, table_name, schema=schema, **kw
)
- if not table_data:
- return []
CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *"
check_constraints = []
@@ -2531,7 +2542,7 @@ class SQLiteDialect(default.DefaultDialect):
# necessarily makes assumptions as to how the CREATE TABLE
# was emitted.
- for match in re.finditer(CHECK_PATTERN, table_data, re.I):
+ for match in re.finditer(CHECK_PATTERN, table_data or "", re.I):
name = match.group(1)
if name:
@@ -2539,7 +2550,10 @@ class SQLiteDialect(default.DefaultDialect):
check_constraints.append({"sqltext": match.group(2), "name": name})
- return check_constraints
+ if check_constraints:
+ return check_constraints
+ else:
+ return ReflectionDefaults.check_constraints()
@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
@@ -2561,7 +2575,7 @@ class SQLiteDialect(default.DefaultDialect):
# loop thru unique indexes to get the column names.
for idx in list(indexes):
pragma_index = self._get_table_pragma(
- connection, "index_info", idx["name"]
+ connection, "index_info", idx["name"], schema=schema
)
for row in pragma_index:
@@ -2574,7 +2588,23 @@ class SQLiteDialect(default.DefaultDialect):
break
else:
idx["column_names"].append(row[2])
- return indexes
+ indexes.sort(key=lambda d: d["name"] or "~") # sort None as last
+ if indexes:
+ return indexes
+ elif not self.has_table(connection, table_name, schema):
+ raise exc.NoSuchTableError(
+ f"{schema}.{table_name}" if schema else table_name
+ )
+ else:
+ return ReflectionDefaults.indexes()
+
+ def _is_sys_table(self, table_name):
+ return table_name in {
+ "sqlite_schema",
+ "sqlite_master",
+ "sqlite_temp_schema",
+ "sqlite_temp_master",
+ }
@reflection.cache
def _get_table_sql(self, connection, table_name, schema=None, **kw):
@@ -2590,22 +2620,25 @@ class SQLiteDialect(default.DefaultDialect):
" (SELECT * FROM %(schema)ssqlite_master UNION ALL "
" SELECT * FROM %(schema)ssqlite_temp_master) "
"WHERE name = ? "
- "AND type = 'table'" % {"schema": schema_expr}
+ "AND type in ('table', 'view')" % {"schema": schema_expr}
)
rs = connection.exec_driver_sql(s, (table_name,))
except exc.DBAPIError:
s = (
"SELECT sql FROM %(schema)ssqlite_master "
"WHERE name = ? "
- "AND type = 'table'" % {"schema": schema_expr}
+ "AND type in ('table', 'view')" % {"schema": schema_expr}
)
rs = connection.exec_driver_sql(s, (table_name,))
- return rs.scalar()
+ value = rs.scalar()
+ if value is None and not self._is_sys_table(table_name):
+ raise exc.NoSuchTableError(f"{schema_expr}{table_name}")
+ return value
def _get_table_pragma(self, connection, pragma, table_name, schema=None):
quote = self.identifier_preparer.quote_identifier
if schema is not None:
- statements = ["PRAGMA %s." % quote(schema)]
+ statements = [f"PRAGMA {quote(schema)}."]
else:
# because PRAGMA looks in all attached databases if no schema
# given, need to specify "main" schema, however since we want
@@ -2615,7 +2648,7 @@ class SQLiteDialect(default.DefaultDialect):
qtable = quote(table_name)
for statement in statements:
- statement = "%s%s(%s)" % (statement, pragma, qtable)
+ statement = f"{statement}{pragma}({qtable})"
cursor = connection.exec_driver_sql(statement)
if not cursor._soft_closed:
# work around SQLite issue whereby cursor.description
diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py
index afba17075..77c2fea40 100644
--- a/lib/sqlalchemy/engine/__init__.py
+++ b/lib/sqlalchemy/engine/__init__.py
@@ -38,6 +38,8 @@ from .interfaces import ExecutionContext as ExecutionContext
from .interfaces import TypeCompiler as TypeCompiler
from .mock import create_mock_engine as create_mock_engine
from .reflection import Inspector as Inspector
+from .reflection import ObjectKind as ObjectKind
+from .reflection import ObjectScope as ObjectScope
from .result import ChunkedIteratorResult as ChunkedIteratorResult
from .result import FrozenResult as FrozenResult
from .result import IteratorResult as IteratorResult
diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py
index df35e7128..40af06252 100644
--- a/lib/sqlalchemy/engine/default.py
+++ b/lib/sqlalchemy/engine/default.py
@@ -45,6 +45,8 @@ from .interfaces import CacheStats
from .interfaces import DBAPICursor
from .interfaces import Dialect
from .interfaces import ExecutionContext
+from .reflection import ObjectKind
+from .reflection import ObjectScope
from .. import event
from .. import exc
from .. import pool
@@ -508,15 +510,22 @@ class DefaultDialect(Dialect):
"""
return type_api.adapt_type(typeobj, self.colspecs)
- def has_index(self, connection, table_name, index_name, schema=None):
- if not self.has_table(connection, table_name, schema=schema):
+ def has_index(self, connection, table_name, index_name, schema=None, **kw):
+ if not self.has_table(connection, table_name, schema=schema, **kw):
return False
- for idx in self.get_indexes(connection, table_name, schema=schema):
+ for idx in self.get_indexes(
+ connection, table_name, schema=schema, **kw
+ ):
if idx["name"] == index_name:
return True
else:
return False
+ def has_schema(
+ self, connection: Connection, schema_name: str, **kw: Any
+ ) -> bool:
+ return schema_name in self.get_schema_names(connection, **kw)
+
def validate_identifier(self, ident):
if len(ident) > self.max_identifier_length:
raise exc.IdentifierError(
@@ -769,6 +778,122 @@ class DefaultDialect(Dialect):
def get_driver_connection(self, connection):
return connection
+ def _overrides_default(self, method):
+ return (
+ getattr(type(self), method).__code__
+ is not getattr(DefaultDialect, method).__code__
+ )
+
+ def _default_multi_reflect(
+ self,
+ single_tbl_method,
+ connection,
+ kind,
+ schema,
+ filter_names,
+ scope,
+ **kw,
+ ):
+
+ names_fns = []
+ temp_names_fns = []
+ if ObjectKind.TABLE in kind:
+ names_fns.append(self.get_table_names)
+ temp_names_fns.append(self.get_temp_table_names)
+ if ObjectKind.VIEW in kind:
+ names_fns.append(self.get_view_names)
+ temp_names_fns.append(self.get_temp_view_names)
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ names_fns.append(self.get_materialized_view_names)
+ # no temp materialized view at the moment
+ # temp_names_fns.append(self.get_temp_materialized_view_names)
+
+ unreflectable = kw.pop("unreflectable", {})
+
+ if (
+ filter_names
+ and scope is ObjectScope.ANY
+ and kind is ObjectKind.ANY
+ ):
+ # if names are given and no qualification on type of table
+ # (i.e. the Table(..., autoload) case), take the names as given,
+ # don't run names queries. If a table does not exit
+ # NoSuchTableError is raised and it's skipped
+
+ # this also suits the case for mssql where we can reflect
+ # individual temp tables but there's no temp_names_fn
+ names = filter_names
+ else:
+ names = []
+ name_kw = {"schema": schema, **kw}
+ fns = []
+ if ObjectScope.DEFAULT in scope:
+ fns.extend(names_fns)
+ if ObjectScope.TEMPORARY in scope:
+ fns.extend(temp_names_fns)
+
+ for fn in fns:
+ try:
+ names.extend(fn(connection, **name_kw))
+ except NotImplementedError:
+ pass
+
+ if filter_names:
+ filter_names = set(filter_names)
+
+ # iterate over all the tables/views and call the single table method
+ for table in names:
+ if not filter_names or table in filter_names:
+ key = (schema, table)
+ try:
+ yield (
+ key,
+ single_tbl_method(
+ connection, table, schema=schema, **kw
+ ),
+ )
+ except exc.UnreflectableTableError as err:
+ if key not in unreflectable:
+ unreflectable[key] = err
+ except exc.NoSuchTableError:
+ pass
+
+ def get_multi_table_options(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_table_options, connection, **kw
+ )
+
+ def get_multi_columns(self, connection, **kw):
+ return self._default_multi_reflect(self.get_columns, connection, **kw)
+
+ def get_multi_pk_constraint(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_pk_constraint, connection, **kw
+ )
+
+ def get_multi_foreign_keys(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_foreign_keys, connection, **kw
+ )
+
+ def get_multi_indexes(self, connection, **kw):
+ return self._default_multi_reflect(self.get_indexes, connection, **kw)
+
+ def get_multi_unique_constraints(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_unique_constraints, connection, **kw
+ )
+
+ def get_multi_check_constraints(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_check_constraints, connection, **kw
+ )
+
+ def get_multi_table_comment(self, connection, **kw):
+ return self._default_multi_reflect(
+ self.get_table_comment, connection, **kw
+ )
+
class StrCompileDialect(DefaultDialect):
diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py
index b8e85b646..28ed03f99 100644
--- a/lib/sqlalchemy/engine/interfaces.py
+++ b/lib/sqlalchemy/engine/interfaces.py
@@ -15,7 +15,9 @@ from typing import Any
from typing import Awaitable
from typing import Callable
from typing import ClassVar
+from typing import Collection
from typing import Dict
+from typing import Iterable
from typing import List
from typing import Mapping
from typing import MutableMapping
@@ -324,7 +326,7 @@ class ReflectedColumn(TypedDict):
nullable: bool
"""column nullability"""
- default: str
+ default: Optional[str]
"""column default expression as a SQL string"""
autoincrement: NotRequired[bool]
@@ -343,11 +345,11 @@ class ReflectedColumn(TypedDict):
comment: NotRequired[Optional[str]]
"""comment for the column, if present"""
- computed: NotRequired[Optional[ReflectedComputed]]
+ computed: NotRequired[ReflectedComputed]
"""indicates this column is computed at insert (possibly update) time by
the database."""
- identity: NotRequired[Optional[ReflectedIdentity]]
+ identity: NotRequired[ReflectedIdentity]
"""indicates this column is an IDENTITY column"""
dialect_options: NotRequired[Dict[str, Any]]
@@ -390,6 +392,9 @@ class ReflectedUniqueConstraint(TypedDict):
column_names: List[str]
"""column names which comprise the constraint"""
+ duplicates_index: NotRequired[Optional[str]]
+ "Indicates if this unique constraint duplicates an index with this name"
+
dialect_options: NotRequired[Dict[str, Any]]
"""Additional dialect-specific options detected for this reflected
object"""
@@ -439,7 +444,7 @@ class ReflectedForeignKeyConstraint(TypedDict):
referred_columns: List[str]
"""referenced column names"""
- dialect_options: NotRequired[Dict[str, Any]]
+ options: NotRequired[Dict[str, Any]]
"""Additional dialect-specific options detected for this reflected
object"""
@@ -462,9 +467,8 @@ class ReflectedIndex(TypedDict):
unique: bool
"""whether or not the index has a unique flag"""
- duplicates_constraint: NotRequired[bool]
- """boolean indicating this index mirrors a unique constraint of the same
- name"""
+ duplicates_constraint: NotRequired[Optional[str]]
+ "Indicates if this index mirrors a unique constraint with this name"
include_columns: NotRequired[List[str]]
"""columns to include in the INCLUDE clause for supporting databases.
@@ -472,7 +476,7 @@ class ReflectedIndex(TypedDict):
.. deprecated:: 2.0
Legacy value, will be replaced with
- ``d["dialect_options"][<dialect name>]["include"]``
+ ``d["dialect_options"]["<dialect name>_include"]``
"""
@@ -494,7 +498,7 @@ class ReflectedTableComment(TypedDict):
"""
- text: str
+ text: Optional[str]
"""text of the comment"""
@@ -547,6 +551,7 @@ class BindTyping(Enum):
VersionInfoType = Tuple[Union[int, str], ...]
+TableKey = Tuple[Optional[str], str]
class Dialect(EventTarget):
@@ -1040,7 +1045,7 @@ class Dialect(EventTarget):
raise NotImplementedError()
- def initialize(self, connection: "Connection") -> None:
+ def initialize(self, connection: Connection) -> None:
"""Called during strategized creation of the dialect with a
connection.
@@ -1060,9 +1065,14 @@ class Dialect(EventTarget):
pass
+ if TYPE_CHECKING:
+
+ def _overrides_default(self, method_name: str) -> bool:
+ ...
+
def get_columns(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
@@ -1074,13 +1084,40 @@ class Dialect(EventTarget):
information as a list of dictionaries
corresponding to the :class:`.ReflectedColumn` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_columns`.
+ """
+
+ def get_multi_columns(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedColumn]]]:
+ """Return information about columns in all tables in the
+ given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_columns`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_pk_constraint(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
@@ -1093,13 +1130,41 @@ class Dialect(EventTarget):
key information as a dictionary corresponding to the
:class:`.ReflectedPrimaryKeyConstraint` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_pk_constraint`.
+
+ """
+ raise NotImplementedError()
+
+ def get_multi_pk_constraint(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, ReflectedPrimaryKeyConstraint]]:
+ """Return information about primary key constraints in
+ all tables in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_pk_constraint`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
"""
raise NotImplementedError()
def get_foreign_keys(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
@@ -1111,42 +1176,104 @@ class Dialect(EventTarget):
key information as a list of dicts corresponding to the
:class:`.ReflectedForeignKeyConstraint` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_foreign_keys`.
+ """
+
+ raise NotImplementedError()
+
+ def get_multi_foreign_keys(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedForeignKeyConstraint]]]:
+ """Return information about foreign_keys in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_multi_foreign_keys`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_table_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
- """Return a list of table names for ``schema``."""
+ """Return a list of table names for ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_table_names`.
+
+ """
raise NotImplementedError()
def get_temp_table_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
"""Return a list of temporary table names on the given connection,
if supported by the underlying backend.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_temp_table_names`.
+
"""
raise NotImplementedError()
def get_view_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ """Return a list of all non-materialized view names available in the
+ database.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_view_names`.
+
+ :param schema: schema name to query, if not the default schema.
+
+ """
+
+ raise NotImplementedError()
+
+ def get_materialized_view_names(
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
- """Return a list of all view names available in the database.
+ """Return a list of all materialized view names available in the
+ database.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_materialized_view_names`.
:param schema: schema name to query, if not the default schema.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_sequence_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
"""Return a list of all sequence names available in the database.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_sequence_names`.
+
:param schema: schema name to query, if not the default schema.
.. versionadded:: 1.4
@@ -1155,26 +1282,40 @@ class Dialect(EventTarget):
raise NotImplementedError()
def get_temp_view_names(
- self, connection: "Connection", schema: Optional[str] = None, **kw: Any
+ self, connection: Connection, schema: Optional[str] = None, **kw: Any
) -> List[str]:
"""Return a list of temporary view names on the given connection,
if supported by the underlying backend.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_temp_view_names`.
+
"""
raise NotImplementedError()
+ def get_schema_names(self, connection: Connection, **kw: Any) -> List[str]:
+ """Return a list of all schema names available in the database.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_schema_names`.
+ """
+ raise NotImplementedError()
+
def get_view_definition(
self,
- connection: "Connection",
+ connection: Connection,
view_name: str,
schema: Optional[str] = None,
**kw: Any,
) -> str:
- """Return view definition.
+ """Return plain or materialized view definition.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_view_definition`.
Given a :class:`_engine.Connection`, a string
- `view_name`, and an optional string ``schema``, return the view
+ ``view_name``, and an optional string ``schema``, return the view
definition.
"""
@@ -1182,7 +1323,7 @@ class Dialect(EventTarget):
def get_indexes(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
@@ -1194,13 +1335,42 @@ class Dialect(EventTarget):
information as a list of dictionaries corresponding to the
:class:`.ReflectedIndex` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_indexes`.
+ """
+
+ raise NotImplementedError()
+
+ def get_multi_indexes(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedIndex]]]:
+ """Return information about indexes in in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_indexes`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_unique_constraints(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
@@ -1211,13 +1381,42 @@ class Dialect(EventTarget):
unique constraint information as a list of dicts corresponding
to the :class:`.ReflectedUniqueConstraint` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_unique_constraints`.
+ """
+
+ raise NotImplementedError()
+
+ def get_multi_unique_constraints(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedUniqueConstraint]]]:
+ """Return information about unique constraints in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_unique_constraints`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_check_constraints(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
@@ -1228,26 +1427,86 @@ class Dialect(EventTarget):
check constraint information as a list of dicts corresponding
to the :class:`.ReflectedCheckConstraint` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_check_constraints`.
+
+ .. versionadded:: 1.1.0
+
+ """
+
+ raise NotImplementedError()
+
+ def get_multi_check_constraints(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, List[ReflectedCheckConstraint]]]:
+ """Return information about check constraints in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_multi_check_constraints`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
def get_table_options(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
- ) -> Optional[Dict[str, Any]]:
- r"""Return the "options" for the table identified by ``table_name``
- as a dictionary.
+ ) -> Dict[str, Any]:
+ """Return a dictionary of options specified when ``table_name``
+ was created.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_table_options`.
"""
- return None
+ raise NotImplementedError()
+
+ def get_multi_table_options(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, Dict[str, Any]]]:
+ """Return a dictionary of options specified when the tables in the
+ given schema were created.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_multi_table_options`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
+ """
+ raise NotImplementedError()
def get_table_comment(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
@@ -1258,6 +1517,8 @@ class Dialect(EventTarget):
table comment information as a dictionary corresponding to the
:class:`.ReflectedTableComment` dictionary.
+ This is an internal dialect method. Applications should use
+ :meth:`.Inspector.get_table_comment`.
:raise: ``NotImplementedError`` for dialects that don't support
comments.
@@ -1268,6 +1529,33 @@ class Dialect(EventTarget):
raise NotImplementedError()
+ def get_multi_table_comment(
+ self,
+ connection: Connection,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ **kw: Any,
+ ) -> Iterable[Tuple[TableKey, ReflectedTableComment]]:
+ """Return information about the table comment in all tables
+ in the given ``schema``.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.get_multi_table_comment`.
+
+ .. note:: The :class:`_engine.DefaultDialect` provides a default
+ implementation that will call the single table method for
+ each object returned by :meth:`Dialect.get_table_names`,
+ :meth:`Dialect.get_view_names` or
+ :meth:`Dialect.get_materialized_view_names` depending on the
+ provided ``kind``. Dialects that want to support a faster
+ implementation should implement this method.
+
+ .. versionadded:: 2.0
+
+ """
+
+ raise NotImplementedError()
+
def normalize_name(self, name: str) -> str:
"""convert the given name to lowercase if it is detected as
case insensitive.
@@ -1290,7 +1578,7 @@ class Dialect(EventTarget):
def has_table(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
schema: Optional[str] = None,
**kw: Any,
@@ -1327,21 +1615,24 @@ class Dialect(EventTarget):
def has_index(
self,
- connection: "Connection",
+ connection: Connection,
table_name: str,
index_name: str,
schema: Optional[str] = None,
+ **kw: Any,
) -> bool:
"""Check the existence of a particular index name in the database.
Given a :class:`_engine.Connection` object, a string
- ``table_name`` and string index name, return True if an index of the
- given name on the given table exists, false otherwise.
+ ``table_name`` and string index name, return ``True`` if an index of
+ the given name on the given table exists, ``False`` otherwise.
The :class:`.DefaultDialect` implements this in terms of the
:meth:`.Dialect.has_table` and :meth:`.Dialect.get_indexes` methods,
however dialects can implement a more performant version.
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.has_index`.
.. versionadded:: 1.4
@@ -1351,7 +1642,7 @@ class Dialect(EventTarget):
def has_sequence(
self,
- connection: "Connection",
+ connection: Connection,
sequence_name: str,
schema: Optional[str] = None,
**kw: Any,
@@ -1359,13 +1650,39 @@ class Dialect(EventTarget):
"""Check the existence of a particular sequence in the database.
Given a :class:`_engine.Connection` object and a string
- `sequence_name`, return True if the given sequence exists in
- the database, False otherwise.
+ `sequence_name`, return ``True`` if the given sequence exists in
+ the database, ``False`` otherwise.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.has_sequence`.
+ """
+
+ raise NotImplementedError()
+
+ def has_schema(
+ self, connection: Connection, schema_name: str, **kw: Any
+ ) -> bool:
+ """Check the existence of a particular schema name in the database.
+
+ Given a :class:`_engine.Connection` object, a string
+ ``schema_name``, return ``True`` if a schema of the
+ given exists, ``False`` otherwise.
+
+ The :class:`.DefaultDialect` implements this by checking
+ the presence of ``schema_name`` among the schemas returned by
+ :meth:`.Dialect.get_schema_names`,
+ however dialects can implement a more performant version.
+
+ This is an internal dialect method. Applications should use
+ :meth:`_engine.Inspector.has_schema`.
+
+ .. versionadded:: 2.0
+
"""
raise NotImplementedError()
- def _get_server_version_info(self, connection: "Connection") -> Any:
+ def _get_server_version_info(self, connection: Connection) -> Any:
"""Retrieve the server version info from the given connection.
This is used by the default implementation to populate the
@@ -1376,7 +1693,7 @@ class Dialect(EventTarget):
raise NotImplementedError()
- def _get_default_schema_name(self, connection: "Connection") -> str:
+ def _get_default_schema_name(self, connection: Connection) -> str:
"""Return the string name of the currently selected schema from
the given connection.
@@ -1481,7 +1798,7 @@ class Dialect(EventTarget):
raise NotImplementedError()
- def do_savepoint(self, connection: "Connection", name: str) -> None:
+ def do_savepoint(self, connection: Connection, name: str) -> None:
"""Create a savepoint with the given name.
:param connection: a :class:`_engine.Connection`.
@@ -1492,7 +1809,7 @@ class Dialect(EventTarget):
raise NotImplementedError()
def do_rollback_to_savepoint(
- self, connection: "Connection", name: str
+ self, connection: Connection, name: str
) -> None:
"""Rollback a connection to the named savepoint.
@@ -1503,9 +1820,7 @@ class Dialect(EventTarget):
raise NotImplementedError()
- def do_release_savepoint(
- self, connection: "Connection", name: str
- ) -> None:
+ def do_release_savepoint(self, connection: Connection, name: str) -> None:
"""Release the named savepoint on a connection.
:param connection: a :class:`_engine.Connection`.
@@ -1514,7 +1829,7 @@ class Dialect(EventTarget):
raise NotImplementedError()
- def do_begin_twophase(self, connection: "Connection", xid: Any) -> None:
+ def do_begin_twophase(self, connection: Connection, xid: Any) -> None:
"""Begin a two phase transaction on the given connection.
:param connection: a :class:`_engine.Connection`.
@@ -1524,7 +1839,7 @@ class Dialect(EventTarget):
raise NotImplementedError()
- def do_prepare_twophase(self, connection: "Connection", xid: Any) -> None:
+ def do_prepare_twophase(self, connection: Connection, xid: Any) -> None:
"""Prepare a two phase transaction on the given connection.
:param connection: a :class:`_engine.Connection`.
@@ -1536,7 +1851,7 @@ class Dialect(EventTarget):
def do_rollback_twophase(
self,
- connection: "Connection",
+ connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
@@ -1555,7 +1870,7 @@ class Dialect(EventTarget):
def do_commit_twophase(
self,
- connection: "Connection",
+ connection: Connection,
xid: Any,
is_prepared: bool = True,
recover: bool = False,
@@ -1573,7 +1888,7 @@ class Dialect(EventTarget):
raise NotImplementedError()
- def do_recover_twophase(self, connection: "Connection") -> List[Any]:
+ def do_recover_twophase(self, connection: Connection) -> List[Any]:
"""Recover list of uncommitted prepared two phase transaction
identifiers on the given connection.
diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py
index 4fc57d5f4..32c89106b 100644
--- a/lib/sqlalchemy/engine/reflection.py
+++ b/lib/sqlalchemy/engine/reflection.py
@@ -27,39 +27,148 @@ methods such as get_table_names, get_columns, etc.
from __future__ import annotations
import contextlib
+from dataclasses import dataclass
+from enum import auto
+from enum import Flag
+from enum import unique
+from typing import Any
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import Generator
+from typing import Iterable
from typing import List
from typing import Optional
+from typing import Sequence
+from typing import Set
+from typing import Tuple
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
from .base import Connection
from .base import Engine
-from .interfaces import ReflectedColumn
from .. import exc
from .. import inspection
from .. import sql
from .. import util
from ..sql import operators
from ..sql import schema as sa_schema
+from ..sql.cache_key import _ad_hoc_cache_key_from_args
+from ..sql.elements import TextClause
from ..sql.type_api import TypeEngine
+from ..sql.visitors import InternalTraversal
from ..util import topological
+from ..util.typing import final
+
+if TYPE_CHECKING:
+ from .interfaces import Dialect
+ from .interfaces import ReflectedCheckConstraint
+ from .interfaces import ReflectedColumn
+ from .interfaces import ReflectedForeignKeyConstraint
+ from .interfaces import ReflectedIndex
+ from .interfaces import ReflectedPrimaryKeyConstraint
+ from .interfaces import ReflectedTableComment
+ from .interfaces import ReflectedUniqueConstraint
+ from .interfaces import TableKey
+
+_R = TypeVar("_R")
@util.decorator
-def cache(fn, self, con, *args, **kw):
+def cache(
+ fn: Callable[..., _R],
+ self: Dialect,
+ con: Connection,
+ *args: Any,
+ **kw: Any,
+) -> _R:
info_cache = kw.get("info_cache", None)
if info_cache is None:
return fn(self, con, *args, **kw)
+ exclude = {"info_cache", "unreflectable"}
key = (
fn.__name__,
tuple(a for a in args if isinstance(a, str)),
- tuple((k, v) for k, v in kw.items() if k != "info_cache"),
+ tuple((k, v) for k, v in kw.items() if k not in exclude),
)
- ret = info_cache.get(key)
+ ret: _R = info_cache.get(key)
if ret is None:
ret = fn(self, con, *args, **kw)
info_cache[key] = ret
return ret
+def flexi_cache(
+ *traverse_args: Tuple[str, InternalTraversal]
+) -> Callable[[Callable[..., _R]], Callable[..., _R]]:
+ @util.decorator
+ def go(
+ fn: Callable[..., _R],
+ self: Dialect,
+ con: Connection,
+ *args: Any,
+ **kw: Any,
+ ) -> _R:
+ info_cache = kw.get("info_cache", None)
+ if info_cache is None:
+ return fn(self, con, *args, **kw)
+ key = _ad_hoc_cache_key_from_args((fn.__name__,), traverse_args, args)
+ ret: _R = info_cache.get(key)
+ if ret is None:
+ ret = fn(self, con, *args, **kw)
+ info_cache[key] = ret
+ return ret
+
+ return go
+
+
+@unique
+class ObjectKind(Flag):
+ """Enumerator that indicates which kind of object to return when calling
+ the ``get_multi`` methods.
+
+ This is a Flag enum, so custom combinations can be passed. For example,
+ to reflect tables and plain views ``ObjectKind.TABLE | ObjectKind.VIEW``
+ may be used.
+
+ .. note::
+ Not all dialect may support all kind of object. If a dialect does
+ not support a particular object an empty dict is returned.
+ In case a dialect supports an object, but the requested method
+ is not applicable for the specified kind the default value
+ will be returned for each reflected object. For example reflecting
+ check constraints of view return a dict with all the views with
+ empty lists as values.
+ """
+
+ TABLE = auto()
+ "Reflect table objects"
+ VIEW = auto()
+ "Reflect plain view objects"
+ MATERIALIZED_VIEW = auto()
+ "Reflect materialized view object"
+
+ ANY_VIEW = VIEW | MATERIALIZED_VIEW
+ "Reflect any kind of view objects"
+ ANY = TABLE | VIEW | MATERIALIZED_VIEW
+ "Reflect all type of objects"
+
+
+@unique
+class ObjectScope(Flag):
+ """Enumerator that indicates which scope to use when calling
+ the ``get_multi`` methods.
+ """
+
+ DEFAULT = auto()
+ "Include default scope"
+ TEMPORARY = auto()
+ "Include only temp scope"
+ ANY = DEFAULT | TEMPORARY
+ "Include both default and temp scope"
+
+
@inspection._self_inspects
class Inspector(inspection.Inspectable["Inspector"]):
"""Performs database schema inspection.
@@ -85,6 +194,12 @@ class Inspector(inspection.Inspectable["Inspector"]):
"""
+ bind: Union[Engine, Connection]
+ engine: Engine
+ _op_context_requires_connect: bool
+ dialect: Dialect
+ info_cache: Dict[Any, Any]
+
@util.deprecated(
"1.4",
"The __init__() method on :class:`_reflection.Inspector` "
@@ -96,7 +211,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
"in order to "
"acquire an :class:`_reflection.Inspector`.",
)
- def __init__(self, bind):
+ def __init__(self, bind: Union[Engine, Connection]):
"""Initialize a new :class:`_reflection.Inspector`.
:param bind: a :class:`~sqlalchemy.engine.Connection`,
@@ -108,38 +223,51 @@ class Inspector(inspection.Inspectable["Inspector"]):
:meth:`_reflection.Inspector.from_engine`
"""
- return self._init_legacy(bind)
+ self._init_legacy(bind)
@classmethod
- def _construct(cls, init, bind):
+ def _construct(
+ cls, init: Callable[..., Any], bind: Union[Engine, Connection]
+ ) -> Inspector:
if hasattr(bind.dialect, "inspector"):
- cls = bind.dialect.inspector
+ cls = bind.dialect.inspector # type: ignore[attr-defined]
self = cls.__new__(cls)
init(self, bind)
return self
- def _init_legacy(self, bind):
+ def _init_legacy(self, bind: Union[Engine, Connection]) -> None:
if hasattr(bind, "exec_driver_sql"):
- self._init_connection(bind)
+ self._init_connection(bind) # type: ignore[arg-type]
else:
- self._init_engine(bind)
+ self._init_engine(bind) # type: ignore[arg-type]
- def _init_engine(self, engine):
+ def _init_engine(self, engine: Engine) -> None:
self.bind = self.engine = engine
engine.connect().close()
self._op_context_requires_connect = True
self.dialect = self.engine.dialect
self.info_cache = {}
- def _init_connection(self, connection):
+ def _init_connection(self, connection: Connection) -> None:
self.bind = connection
self.engine = connection.engine
self._op_context_requires_connect = False
self.dialect = self.engine.dialect
self.info_cache = {}
+ def clear_cache(self) -> None:
+ """reset the cache for this :class:`.Inspector`.
+
+ Inspection methods that have data cached will emit SQL queries
+ when next called to get new data.
+
+ .. versionadded:: 2.0
+
+ """
+ self.info_cache.clear()
+
@classmethod
@util.deprecated(
"1.4",
@@ -152,7 +280,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
"in order to "
"acquire an :class:`_reflection.Inspector`.",
)
- def from_engine(cls, bind):
+ def from_engine(cls, bind: Engine) -> Inspector:
"""Construct a new dialect-specific Inspector object from the given
engine or connection.
@@ -172,15 +300,15 @@ class Inspector(inspection.Inspectable["Inspector"]):
return cls._construct(cls._init_legacy, bind)
@inspection._inspects(Engine)
- def _engine_insp(bind):
+ def _engine_insp(bind: Engine) -> Inspector: # type: ignore[misc]
return Inspector._construct(Inspector._init_engine, bind)
@inspection._inspects(Connection)
- def _connection_insp(bind):
+ def _connection_insp(bind: Connection) -> Inspector: # type: ignore[misc]
return Inspector._construct(Inspector._init_connection, bind)
@contextlib.contextmanager
- def _operation_context(self):
+ def _operation_context(self) -> Generator[Connection, None, None]:
"""Return a context that optimizes for multiple operations on a single
transaction.
@@ -189,10 +317,11 @@ class Inspector(inspection.Inspectable["Inspector"]):
:class:`_engine.Connection`.
"""
+ conn: Connection
if self._op_context_requires_connect:
- conn = self.bind.connect()
+ conn = self.bind.connect() # type: ignore[union-attr]
else:
- conn = self.bind
+ conn = self.bind # type: ignore[assignment]
try:
yield conn
finally:
@@ -200,7 +329,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
conn.close()
@contextlib.contextmanager
- def _inspection_context(self):
+ def _inspection_context(self) -> Generator[Inspector, None, None]:
"""Return an :class:`_reflection.Inspector`
from this one that will run all
operations on a single connection.
@@ -213,7 +342,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
yield sub_insp
@property
- def default_schema_name(self):
+ def default_schema_name(self) -> Optional[str]:
"""Return the default schema name presented by the dialect
for the current engine's database user.
@@ -223,30 +352,38 @@ class Inspector(inspection.Inspectable["Inspector"]):
"""
return self.dialect.default_schema_name
- def get_schema_names(self):
- """Return all schema names."""
+ def get_schema_names(self, **kw: Any) -> List[str]:
+ r"""Return all schema names.
- if hasattr(self.dialect, "get_schema_names"):
- with self._operation_context() as conn:
- return self.dialect.get_schema_names(
- conn, info_cache=self.info_cache
- )
- return []
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+ """
- def get_table_names(self, schema=None):
- """Return all table names in referred to within a particular schema.
+ with self._operation_context() as conn:
+ return self.dialect.get_schema_names(
+ conn, info_cache=self.info_cache, **kw
+ )
+
+ def get_table_names(
+ self, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ r"""Return all table names within a particular schema.
The names are expected to be real tables only, not views.
Views are instead returned using the
- :meth:`_reflection.Inspector.get_view_names`
- method.
-
+ :meth:`_reflection.Inspector.get_view_names` and/or
+ :meth:`_reflection.Inspector.get_materialized_view_names`
+ methods.
:param schema: Schema name. If ``schema`` is left at ``None``, the
database's default schema is
used, else the named schema is searched. If the database does not
support named schemas, behavior is undefined if ``schema`` is not
passed as ``None``. For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. seealso::
@@ -258,43 +395,105 @@ class Inspector(inspection.Inspectable["Inspector"]):
with self._operation_context() as conn:
return self.dialect.get_table_names(
- conn, schema, info_cache=self.info_cache
+ conn, schema, info_cache=self.info_cache, **kw
)
- def has_table(self, table_name, schema=None):
- """Return True if the backend has a table or view of the given name.
+ def has_table(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> bool:
+ r"""Return True if the backend has a table or view of the given name.
:param table_name: name of the table to check
:param schema: schema name to query, if not the default schema.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. versionadded:: 1.4 - the :meth:`.Inspector.has_table` method
replaces the :meth:`_engine.Engine.has_table` method.
- .. versionchanged:: 2.0:: The method checks also for views.
+ .. versionchanged:: 2.0:: The method checks also for any type of
+ views (plain or materialized).
In previous version this behaviour was dialect specific. New
dialect suite tests were added to ensure all dialect conform with
this behaviour.
"""
- # TODO: info_cache?
with self._operation_context() as conn:
- return self.dialect.has_table(conn, table_name, schema)
+ return self.dialect.has_table(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
- def has_sequence(self, sequence_name, schema=None):
- """Return True if the backend has a table of the given name.
+ def has_sequence(
+ self, sequence_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> bool:
+ r"""Return True if the backend has a sequence with the given name.
- :param sequence_name: name of the table to check
+ :param sequence_name: name of the sequence to check
:param schema: schema name to query, if not the default schema.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. versionadded:: 1.4
"""
- # TODO: info_cache?
with self._operation_context() as conn:
- return self.dialect.has_sequence(conn, sequence_name, schema)
+ return self.dialect.has_sequence(
+ conn, sequence_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def has_index(
+ self,
+ table_name: str,
+ index_name: str,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> bool:
+ r"""Check the existence of a particular index name in the database.
+
+ :param table_name: the name of the table the index belongs to
+ :param index_name: the name of the index to check
+ :param schema: schema name to query, if not the default schema.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ .. versionadded:: 2.0
+
+ """
+ with self._operation_context() as conn:
+ return self.dialect.has_index(
+ conn,
+ table_name,
+ index_name,
+ schema,
+ info_cache=self.info_cache,
+ **kw,
+ )
+
+ def has_schema(self, schema_name: str, **kw: Any) -> bool:
+ r"""Return True if the backend has a schema with the given name.
+
+ :param schema_name: name of the schema to check
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ .. versionadded:: 2.0
+
+ """
+ with self._operation_context() as conn:
+ return self.dialect.has_schema(
+ conn, schema_name, info_cache=self.info_cache, **kw
+ )
- def get_sorted_table_and_fkc_names(self, schema=None):
- """Return dependency-sorted table and foreign key constraint names in
+ def get_sorted_table_and_fkc_names(
+ self,
+ schema: Optional[str] = None,
+ **kw: Any,
+ ) -> List[Tuple[Optional[str], List[Tuple[str, Optional[str]]]]]:
+ r"""Return dependency-sorted table and foreign key constraint names in
referred to within a particular schema.
This will yield 2-tuples of
@@ -309,6 +508,11 @@ class Inspector(inspection.Inspectable["Inspector"]):
.. versionadded:: 1.0.-
+ :param schema: schema name to query, if not the default schema.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
.. seealso::
:meth:`_reflection.Inspector.get_table_names`
@@ -317,24 +521,74 @@ class Inspector(inspection.Inspectable["Inspector"]):
with an already-given :class:`_schema.MetaData`.
"""
- with self._operation_context() as conn:
- tnames = self.dialect.get_table_names(
- conn, schema, info_cache=self.info_cache
+
+ return [
+ (
+ table_key[1] if table_key else None,
+ [(tname, fks) for (_, tname), fks in fk_collection],
)
+ for (
+ table_key,
+ fk_collection,
+ ) in self.sort_tables_on_foreign_key_dependency(
+ consider_schemas=(schema,)
+ )
+ ]
- tuples = set()
- remaining_fkcs = set()
+ def sort_tables_on_foreign_key_dependency(
+ self,
+ consider_schemas: Collection[Optional[str]] = (None,),
+ **kw: Any,
+ ) -> List[
+ Tuple[
+ Optional[Tuple[Optional[str], str]],
+ List[Tuple[Tuple[Optional[str], str], Optional[str]]],
+ ]
+ ]:
+ r"""Return dependency-sorted table and foreign key constraint names
+ referred to within multiple schemas.
+
+ This method may be compared to
+ :meth:`.Inspector.get_sorted_table_and_fkc_names`, which
+ works on one schema at a time; here, the method is a generalization
+ that will consider multiple schemas at once including that it will
+ resolve for cross-schema foreign keys.
+
+ .. versionadded:: 2.0
- fknames_for_table = {}
- for tname in tnames:
- fkeys = self.get_foreign_keys(tname, schema)
- fknames_for_table[tname] = set([fk["name"] for fk in fkeys])
- for fkey in fkeys:
- if tname != fkey["referred_table"]:
- tuples.add((fkey["referred_table"], tname))
+ """
+ SchemaTab = Tuple[Optional[str], str]
+
+ tuples: Set[Tuple[SchemaTab, SchemaTab]] = set()
+ remaining_fkcs: Set[Tuple[SchemaTab, Optional[str]]] = set()
+ fknames_for_table: Dict[SchemaTab, Set[Optional[str]]] = {}
+ tnames: List[SchemaTab] = []
+
+ for schname in consider_schemas:
+ schema_fkeys = self.get_multi_foreign_keys(schname, **kw)
+ tnames.extend(schema_fkeys)
+ for (_, tname), fkeys in schema_fkeys.items():
+ fknames_for_table[(schname, tname)] = set(
+ [fk["name"] for fk in fkeys]
+ )
+ for fkey in fkeys:
+ if (
+ tname != fkey["referred_table"]
+ or schname != fkey["referred_schema"]
+ ):
+ tuples.add(
+ (
+ (
+ fkey["referred_schema"],
+ fkey["referred_table"],
+ ),
+ (schname, tname),
+ )
+ )
try:
candidate_sort = list(topological.sort(tuples, tnames))
except exc.CircularDependencyError as err:
+ edge: Tuple[SchemaTab, SchemaTab]
for edge in err.edges:
tuples.remove(edge)
remaining_fkcs.update(
@@ -342,16 +596,32 @@ class Inspector(inspection.Inspectable["Inspector"]):
)
candidate_sort = list(topological.sort(tuples, tnames))
- return [
- (tname, fknames_for_table[tname].difference(remaining_fkcs))
- for tname in candidate_sort
- ] + [(None, list(remaining_fkcs))]
+ ret: List[
+ Tuple[Optional[SchemaTab], List[Tuple[SchemaTab, Optional[str]]]]
+ ]
+ ret = [
+ (
+ (schname, tname),
+ [
+ ((schname, tname), fk)
+ for fk in fknames_for_table[(schname, tname)].difference(
+ name for _, name in remaining_fkcs
+ )
+ ],
+ )
+ for (schname, tname) in candidate_sort
+ ]
+ return ret + [(None, list(remaining_fkcs))]
- def get_temp_table_names(self):
- """Return a list of temporary table names for the current bind.
+ def get_temp_table_names(self, **kw: Any) -> List[str]:
+ r"""Return a list of temporary table names for the current bind.
This method is unsupported by most dialects; currently
- only SQLite implements it.
+ only Oracle, PostgreSQL and SQLite implements it.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. versionadded:: 1.0.0
@@ -359,28 +629,35 @@ class Inspector(inspection.Inspectable["Inspector"]):
with self._operation_context() as conn:
return self.dialect.get_temp_table_names(
- conn, info_cache=self.info_cache
+ conn, info_cache=self.info_cache, **kw
)
- def get_temp_view_names(self):
- """Return a list of temporary view names for the current bind.
+ def get_temp_view_names(self, **kw: Any) -> List[str]:
+ r"""Return a list of temporary view names for the current bind.
This method is unsupported by most dialects; currently
- only SQLite implements it.
+ only PostgreSQL and SQLite implements it.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
.. versionadded:: 1.0.0
"""
with self._operation_context() as conn:
return self.dialect.get_temp_view_names(
- conn, info_cache=self.info_cache
+ conn, info_cache=self.info_cache, **kw
)
- def get_table_options(self, table_name, schema=None, **kw):
- """Return a dictionary of options specified when the table of the
+ def get_table_options(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> Dict[str, Any]:
+ r"""Return a dictionary of options specified when the table of the
given name was created.
- This currently includes some options that apply to MySQL tables.
+ This currently includes some options that apply to MySQL and Oracle
+ tables.
:param table_name: string name of the table. For special quoting,
use :class:`.quoted_name`.
@@ -389,60 +666,172 @@ class Inspector(inspection.Inspectable["Inspector"]):
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dict with the table options. The returned keys depend on the
+ dialect in use. Each one is prefixed with the dialect name.
+
"""
- if hasattr(self.dialect, "get_table_options"):
- with self._operation_context() as conn:
- return self.dialect.get_table_options(
- conn, table_name, schema, info_cache=self.info_cache, **kw
- )
- return {}
+ with self._operation_context() as conn:
+ return self.dialect.get_table_options(
+ conn, table_name, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_multi_table_options(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, Dict[str, Any]]:
+ r"""Return a dictionary of options specified when the tables in the
+ given schema were created.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ This currently includes some options that apply to MySQL and Oracle
+ tables.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if options of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuple schema,table-name
+ and the values are dictionaries with the table options.
+ The returned keys in each dict depend on the
+ dialect in use. Each one is prefixed with the dialect name.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+ with self._operation_context() as conn:
+ res = self.dialect.get_multi_table_options(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ return dict(res)
- def get_view_names(self, schema=None):
- """Return all view names in `schema`.
+ def get_view_names(
+ self, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ r"""Return all non-materialized view names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+
+ .. versionchanged:: 2.0 For those dialects that previously included
+ the names of materialized views in this list (currently PostgreSQL),
+ this method no longer returns the names of materialized views.
+ the :meth:`.Inspector.get_materialized_view_names` method should
+ be used instead.
+
+ .. seealso::
+
+ :meth:`.Inspector.get_materialized_view_names`
"""
with self._operation_context() as conn:
return self.dialect.get_view_names(
- conn, schema, info_cache=self.info_cache
+ conn, schema, info_cache=self.info_cache, **kw
+ )
+
+ def get_materialized_view_names(
+ self, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ r"""Return all materialized view names in `schema`.
+
+ :param schema: Optional, retrieve names from a non-default schema.
+ For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ .. versionadded:: 2.0
+
+ .. seealso::
+
+ :meth:`.Inspector.get_view_names`
+
+ """
+
+ with self._operation_context() as conn:
+ return self.dialect.get_materialized_view_names(
+ conn, schema, info_cache=self.info_cache, **kw
)
- def get_sequence_names(self, schema=None):
- """Return all sequence names in `schema`.
+ def get_sequence_names(
+ self, schema: Optional[str] = None, **kw: Any
+ ) -> List[str]:
+ r"""Return all sequence names in `schema`.
:param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
"""
with self._operation_context() as conn:
return self.dialect.get_sequence_names(
- conn, schema, info_cache=self.info_cache
+ conn, schema, info_cache=self.info_cache, **kw
)
- def get_view_definition(self, view_name, schema=None):
- """Return definition for `view_name`.
+ def get_view_definition(
+ self, view_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> str:
+ r"""Return definition for the plain or materialized view called
+ ``view_name``.
+ :param view_name: Name of the view.
:param schema: Optional, retrieve names from a non-default schema.
For special quoting, use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
"""
with self._operation_context() as conn:
return self.dialect.get_view_definition(
- conn, view_name, schema, info_cache=self.info_cache
+ conn, view_name, schema, info_cache=self.info_cache, **kw
)
def get_columns(
- self, table_name: str, schema: Optional[str] = None, **kw
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
) -> List[ReflectedColumn]:
- """Return information about columns in `table_name`.
+ r"""Return information about columns in ``table_name``.
- Given a string `table_name` and an optional string `schema`, return
- column information as a list of dicts with these keys:
+ Given a string ``table_name`` and an optional string ``schema``,
+ return column information as a list of dicts with these keys:
* ``name`` - the column's name
@@ -487,6 +876,10 @@ class Inspector(inspection.Inspectable["Inspector"]):
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
:return: list of dictionaries, each representing the definition of
a database column.
@@ -496,17 +889,83 @@ class Inspector(inspection.Inspectable["Inspector"]):
col_defs = self.dialect.get_columns(
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- for col_def in col_defs:
- # make this easy and only return instances for coltype
- coltype = col_def["type"]
- if not isinstance(coltype, TypeEngine):
- col_def["type"] = coltype()
+ if col_defs:
+ self._instantiate_types([col_defs])
return col_defs
- def get_pk_constraint(self, table_name, schema=None, **kw):
- """Return information about primary key constraint on `table_name`.
+ def _instantiate_types(
+ self, data: Iterable[List[ReflectedColumn]]
+ ) -> None:
+ # make this easy and only return instances for coltype
+ for col_defs in data:
+ for col_def in col_defs:
+ coltype = col_def["type"]
+ if not isinstance(coltype, TypeEngine):
+ col_def["type"] = coltype()
+
+ def get_multi_columns(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedColumn]]:
+ r"""Return information about columns in all objects in the given schema.
+
+ The objects can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The column information is as described in
+ :meth:`Inspector.get_columns`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if columns of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuple schema,table-name
+ and the values are list of dictionaries, each representing the
+ definition of a database column.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ table_col_defs = dict(
+ self.dialect.get_multi_columns(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+ self._instantiate_types(table_col_defs.values())
+ return table_col_defs
+
+ def get_pk_constraint(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> ReflectedPrimaryKeyConstraint:
+ r"""Return information about primary key constraint in ``table_name``.
- Given a string `table_name`, and an optional string `schema`, return
+ Given a string ``table_name``, and an optional string `schema`, return
primary key information as a dictionary with these keys:
* ``constrained_columns`` -
@@ -522,16 +981,80 @@ class Inspector(inspection.Inspectable["Inspector"]):
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary representing the definition of
+ a primary key constraint.
+
"""
with self._operation_context() as conn:
return self.dialect.get_pk_constraint(
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_foreign_keys(self, table_name, schema=None, **kw):
- """Return information about foreign_keys in `table_name`.
+ def get_multi_pk_constraint(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, ReflectedPrimaryKeyConstraint]:
+ r"""Return information about primary key constraints in
+ all tables in the given schema.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The primary key information is as described in
+ :meth:`Inspector.get_pk_constraint`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if primary keys of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuple schema,table-name
+ and the values are dictionaries, each representing the
+ definition of a primary key constraint.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_pk_constraint(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
+ def get_foreign_keys(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> List[ReflectedForeignKeyConstraint]:
+ r"""Return information about foreign_keys in ``table_name``.
- Given a string `table_name`, and an optional string `schema`, return
+ Given a string ``table_name``, and an optional string `schema`, return
foreign key information as a list of dicts with these keys:
* ``constrained_columns`` -
@@ -557,6 +1080,13 @@ class Inspector(inspection.Inspectable["Inspector"]):
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a list of dictionaries, each representing the
+ a foreign key definition.
+
"""
with self._operation_context() as conn:
@@ -564,10 +1094,68 @@ class Inspector(inspection.Inspectable["Inspector"]):
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_indexes(self, table_name, schema=None, **kw):
- """Return information about indexes in `table_name`.
+ def get_multi_foreign_keys(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedForeignKeyConstraint]]:
+ r"""Return information about foreign_keys in all tables
+ in the given schema.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The foreign key informations as described in
+ :meth:`Inspector.get_foreign_keys`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if foreign keys of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuple schema,table-name
+ and the values are list of dictionaries, each representing
+ a foreign key definition.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_foreign_keys(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
- Given a string `table_name` and an optional string `schema`, return
+ def get_indexes(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> List[ReflectedIndex]:
+ r"""Return information about indexes in ``table_name``.
+
+ Given a string ``table_name`` and an optional string `schema`, return
index information as a list of dicts with these keys:
* ``name`` -
@@ -598,6 +1186,13 @@ class Inspector(inspection.Inspectable["Inspector"]):
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a list of dictionaries, each representing the
+ definition of an index.
+
"""
with self._operation_context() as conn:
@@ -605,10 +1200,71 @@ class Inspector(inspection.Inspectable["Inspector"]):
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_unique_constraints(self, table_name, schema=None, **kw):
- """Return information about unique constraints in `table_name`.
+ def get_multi_indexes(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedIndex]]:
+ r"""Return information about indexes in in all objects
+ in the given schema.
+
+ The objects can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The foreign key information is as described in
+ :meth:`Inspector.get_foreign_keys`.
+
+ The indexes information as described in
+ :meth:`Inspector.get_indexes`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if indexes of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuple schema,table-name
+ and the values are list of dictionaries, each representing the
+ definition of an index.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_indexes(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
+ def get_unique_constraints(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> List[ReflectedUniqueConstraint]:
+ r"""Return information about unique constraints in ``table_name``.
- Given a string `table_name` and an optional string `schema`, return
+ Given a string ``table_name`` and an optional string `schema`, return
unique constraint information as a list of dicts with these keys:
* ``name`` -
@@ -624,6 +1280,13 @@ class Inspector(inspection.Inspectable["Inspector"]):
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a list of dictionaries, each representing the
+ definition of an unique constraint.
+
"""
with self._operation_context() as conn:
@@ -631,8 +1294,66 @@ class Inspector(inspection.Inspectable["Inspector"]):
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_table_comment(self, table_name, schema=None, **kw):
- """Return information about the table comment for ``table_name``.
+ def get_multi_unique_constraints(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedUniqueConstraint]]:
+ r"""Return information about unique constraints in all tables
+ in the given schema.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The unique constraint information is as described in
+ :meth:`Inspector.get_unique_constraints`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if constraints of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuple schema,table-name
+ and the values are list of dictionaries, each representing the
+ definition of an unique constraint.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_unique_constraints(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
+ def get_table_comment(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> ReflectedTableComment:
+ r"""Return information about the table comment for ``table_name``.
Given a string ``table_name`` and an optional string ``schema``,
return table comment information as a dictionary with these keys:
@@ -643,8 +1364,20 @@ class Inspector(inspection.Inspectable["Inspector"]):
Raises ``NotImplementedError`` for a dialect that does not support
comments.
- .. versionadded:: 1.2
+ :param table_name: string name of the table. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary, with the table comment.
+ .. versionadded:: 1.2
"""
with self._operation_context() as conn:
@@ -652,10 +1385,71 @@ class Inspector(inspection.Inspectable["Inspector"]):
conn, table_name, schema, info_cache=self.info_cache, **kw
)
- def get_check_constraints(self, table_name, schema=None, **kw):
- """Return information about check constraints in `table_name`.
+ def get_multi_table_comment(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, ReflectedTableComment]:
+ r"""Return information about the table comment in all objects
+ in the given schema.
+
+ The objects can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The comment information is as described in
+ :meth:`Inspector.get_table_comment`.
+
+ Raises ``NotImplementedError`` for a dialect that does not support
+ comments.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if comments of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuple schema,table-name
+ and the values are dictionaries, representing the
+ table comments.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_table_comment(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
+ def get_check_constraints(
+ self, table_name: str, schema: Optional[str] = None, **kw: Any
+ ) -> List[ReflectedCheckConstraint]:
+ r"""Return information about check constraints in ``table_name``.
- Given a string `table_name` and an optional string `schema`, return
+ Given a string ``table_name`` and an optional string `schema`, return
check constraint information as a list of dicts with these keys:
* ``name`` -
@@ -677,6 +1471,13 @@ class Inspector(inspection.Inspectable["Inspector"]):
of the database connection. For special quoting,
use :class:`.quoted_name`.
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a list of dictionaries, each representing the
+ definition of a check constraints.
+
.. versionadded:: 1.1.0
"""
@@ -686,14 +1487,71 @@ class Inspector(inspection.Inspectable["Inspector"]):
conn, table_name, schema, info_cache=self.info_cache, **kw
)
+ def get_multi_check_constraints(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Sequence[str]] = None,
+ kind: ObjectKind = ObjectKind.TABLE,
+ scope: ObjectScope = ObjectScope.DEFAULT,
+ **kw: Any,
+ ) -> Dict[TableKey, List[ReflectedCheckConstraint]]:
+ r"""Return information about check constraints in all tables
+ in the given schema.
+
+ The tables can be filtered by passing the names to use to
+ ``filter_names``.
+
+ The check constraint information is as described in
+ :meth:`Inspector.get_check_constraints`.
+
+ :param schema: string schema name; if omitted, uses the default schema
+ of the database connection. For special quoting,
+ use :class:`.quoted_name`.
+
+ :param filter_names: optionally return information only for the
+ objects listed here.
+
+ :param kind: a :class:`.ObjectKind` that specifies the type of objects
+ to reflect. Defaults to ``ObjectKind.TABLE``.
+
+ :param scope: a :class:`.ObjectScope` that specifies if constraints of
+ default, temporary or any tables should be reflected.
+ Defaults to ``ObjectScope.DEFAULT``.
+
+ :param \**kw: Additional keyword argument to pass to the dialect
+ specific implementation. See the documentation of the dialect
+ in use for more information.
+
+ :return: a dictionary where the keys are two-tuple schema,table-name
+ and the values are list of dictionaries, each representing the
+ definition of a check constraints.
+ The schema is ``None`` if no schema is provided.
+
+ .. versionadded:: 2.0
+ """
+
+ with self._operation_context() as conn:
+ return dict(
+ self.dialect.get_multi_check_constraints(
+ conn,
+ schema=schema,
+ filter_names=filter_names,
+ kind=kind,
+ scope=scope,
+ info_cache=self.info_cache,
+ **kw,
+ )
+ )
+
def reflect_table(
self,
- table,
- include_columns,
- exclude_columns=(),
- resolve_fks=True,
- _extend_on=None,
- ):
+ table: sa_schema.Table,
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str] = (),
+ resolve_fks: bool = True,
+ _extend_on: Optional[Set[sa_schema.Table]] = None,
+ _reflect_info: Optional[_ReflectionInfo] = None,
+ ) -> None:
"""Given a :class:`_schema.Table` object, load its internal
constructs based on introspection.
@@ -741,21 +1599,34 @@ class Inspector(inspection.Inspectable["Inspector"]):
if k in table.dialect_kwargs
)
+ table_key = (schema, table_name)
+ if _reflect_info is None or table_key not in _reflect_info.columns:
+ _reflect_info = self._get_reflection_info(
+ schema,
+ filter_names=[table_name],
+ kind=ObjectKind.ANY,
+ scope=ObjectScope.ANY,
+ _reflect_info=_reflect_info,
+ **table.dialect_kwargs,
+ )
+ if table_key in _reflect_info.unreflectable:
+ raise _reflect_info.unreflectable[table_key]
+
+ if table_key not in _reflect_info.columns:
+ raise exc.NoSuchTableError(table_name)
+
# reflect table options, like mysql_engine
- tbl_opts = self.get_table_options(
- table_name, schema, **table.dialect_kwargs
- )
- if tbl_opts:
- # add additional kwargs to the Table if the dialect
- # returned them
- table._validate_dialect_kwargs(tbl_opts)
+ if _reflect_info.table_options:
+ tbl_opts = _reflect_info.table_options.get(table_key)
+ if tbl_opts:
+ # add additional kwargs to the Table if the dialect
+ # returned them
+ table._validate_dialect_kwargs(tbl_opts)
found_table = False
- cols_by_orig_name = {}
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]] = {}
- for col_d in self.get_columns(
- table_name, schema, **table.dialect_kwargs
- ):
+ for col_d in _reflect_info.columns[table_key]:
found_table = True
self._reflect_column(
@@ -771,12 +1642,12 @@ class Inspector(inspection.Inspectable["Inspector"]):
raise exc.NoSuchTableError(table_name)
self._reflect_pk(
- table_name, schema, table, cols_by_orig_name, exclude_columns
+ _reflect_info, table_key, table, cols_by_orig_name, exclude_columns
)
self._reflect_fk(
- table_name,
- schema,
+ _reflect_info,
+ table_key,
table,
cols_by_orig_name,
include_columns,
@@ -787,8 +1658,8 @@ class Inspector(inspection.Inspectable["Inspector"]):
)
self._reflect_indexes(
- table_name,
- schema,
+ _reflect_info,
+ table_key,
table,
cols_by_orig_name,
include_columns,
@@ -797,8 +1668,8 @@ class Inspector(inspection.Inspectable["Inspector"]):
)
self._reflect_unique_constraints(
- table_name,
- schema,
+ _reflect_info,
+ table_key,
table,
cols_by_orig_name,
include_columns,
@@ -807,8 +1678,8 @@ class Inspector(inspection.Inspectable["Inspector"]):
)
self._reflect_check_constraints(
- table_name,
- schema,
+ _reflect_info,
+ table_key,
table,
cols_by_orig_name,
include_columns,
@@ -817,17 +1688,27 @@ class Inspector(inspection.Inspectable["Inspector"]):
)
self._reflect_table_comment(
- table_name, schema, table, reflection_options
+ _reflect_info,
+ table_key,
+ table,
+ reflection_options,
)
def _reflect_column(
- self, table, col_d, include_columns, exclude_columns, cols_by_orig_name
- ):
+ self,
+ table: sa_schema.Table,
+ col_d: ReflectedColumn,
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ ) -> None:
orig_name = col_d["name"]
table.metadata.dispatch.column_reflect(self, table, col_d)
- table.dispatch.column_reflect(self, table, col_d)
+ table.dispatch.column_reflect( # type: ignore[attr-defined]
+ self, table, col_d
+ )
# fetch name again as column_reflect is allowed to
# change it
@@ -840,7 +1721,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
coltype = col_d["type"]
col_kw = dict(
- (k, col_d[k])
+ (k, col_d[k]) # type: ignore[literal-required]
for k in [
"nullable",
"autoincrement",
@@ -856,15 +1737,20 @@ class Inspector(inspection.Inspectable["Inspector"]):
col_kw.update(col_d["dialect_options"])
colargs = []
+ default: Any
if col_d.get("default") is not None:
- default = col_d["default"]
- if isinstance(default, sql.elements.TextClause):
- default = sa_schema.DefaultClause(default, _reflected=True)
- elif not isinstance(default, sa_schema.FetchedValue):
+ default_text = col_d["default"]
+ assert default_text is not None
+ if isinstance(default_text, TextClause):
default = sa_schema.DefaultClause(
- sql.text(col_d["default"]), _reflected=True
+ default_text, _reflected=True
)
-
+ elif not isinstance(default_text, sa_schema.FetchedValue):
+ default = sa_schema.DefaultClause(
+ sql.text(default_text), _reflected=True
+ )
+ else:
+ default = default_text
colargs.append(default)
if "computed" in col_d:
@@ -872,11 +1758,8 @@ class Inspector(inspection.Inspectable["Inspector"]):
colargs.append(computed)
if "identity" in col_d:
- computed = sa_schema.Identity(**col_d["identity"])
- colargs.append(computed)
-
- if "sequence" in col_d:
- self._reflect_col_sequence(col_d, colargs)
+ identity = sa_schema.Identity(**col_d["identity"])
+ colargs.append(identity)
cols_by_orig_name[orig_name] = col = sa_schema.Column(
name, coltype, *colargs, **col_kw
@@ -886,23 +1769,15 @@ class Inspector(inspection.Inspectable["Inspector"]):
col.primary_key = True
table.append_column(col, replace_existing=True)
- def _reflect_col_sequence(self, col_d, colargs):
- if "sequence" in col_d:
- # TODO: mssql is using this.
- seq = col_d["sequence"]
- sequence = sa_schema.Sequence(seq["name"], 1, 1)
- if "start" in seq:
- sequence.start = seq["start"]
- if "increment" in seq:
- sequence.increment = seq["increment"]
- colargs.append(sequence)
-
def _reflect_pk(
- self, table_name, schema, table, cols_by_orig_name, exclude_columns
- ):
- pk_cons = self.get_pk_constraint(
- table_name, schema, **table.dialect_kwargs
- )
+ self,
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ exclude_columns: Collection[str],
+ ) -> None:
+ pk_cons = _reflect_info.pk_constraint.get(table_key)
if pk_cons:
pk_cols = [
cols_by_orig_name[pk]
@@ -919,19 +1794,17 @@ class Inspector(inspection.Inspectable["Inspector"]):
def _reflect_fk(
self,
- table_name,
- schema,
- table,
- cols_by_orig_name,
- include_columns,
- exclude_columns,
- resolve_fks,
- _extend_on,
- reflection_options,
- ):
- fkeys = self.get_foreign_keys(
- table_name, schema, **table.dialect_kwargs
- )
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ resolve_fks: bool,
+ _extend_on: Optional[Set[sa_schema.Table]],
+ reflection_options: Dict[str, Any],
+ ) -> None:
+ fkeys = _reflect_info.foreign_keys.get(table_key, [])
for fkey_d in fkeys:
conname = fkey_d["name"]
# look for columns by orig name in cols_by_orig_name,
@@ -963,6 +1836,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
schema=referred_schema,
autoload_with=self.bind,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
**reflection_options,
)
for column in referred_columns:
@@ -977,6 +1851,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
autoload_with=self.bind,
schema=sa_schema.BLANK_SCHEMA,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
**reflection_options,
)
for column in referred_columns:
@@ -1005,16 +1880,16 @@ class Inspector(inspection.Inspectable["Inspector"]):
def _reflect_indexes(
self,
- table_name,
- schema,
- table,
- cols_by_orig_name,
- include_columns,
- exclude_columns,
- reflection_options,
- ):
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ reflection_options: Dict[str, Any],
+ ) -> None:
# Indexes
- indexes = self.get_indexes(table_name, schema)
+ indexes = _reflect_info.indexes.get(table_key, [])
for index_d in indexes:
name = index_d["name"]
columns = index_d["column_names"]
@@ -1034,6 +1909,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
continue
# look for columns by orig name in cols_by_orig_name,
# but support columns that are in-Python only as fallback
+ idx_col: Any
idx_cols = []
for c in columns:
try:
@@ -1045,7 +1921,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
except KeyError:
util.warn(
"%s key '%s' was not located in "
- "columns for table '%s'" % (flavor, c, table_name)
+ "columns for table '%s'" % (flavor, c, table.name)
)
continue
c_sorting = column_sorting.get(c, ())
@@ -1063,22 +1939,16 @@ class Inspector(inspection.Inspectable["Inspector"]):
def _reflect_unique_constraints(
self,
- table_name,
- schema,
- table,
- cols_by_orig_name,
- include_columns,
- exclude_columns,
- reflection_options,
- ):
-
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ reflection_options: Dict[str, Any],
+ ) -> None:
+ constraints = _reflect_info.unique_constraints.get(table_key, [])
# Unique Constraints
- try:
- constraints = self.get_unique_constraints(table_name, schema)
- except NotImplementedError:
- # optional dialect feature
- return
-
for const_d in constraints:
conname = const_d["name"]
columns = const_d["column_names"]
@@ -1104,7 +1974,7 @@ class Inspector(inspection.Inspectable["Inspector"]):
except KeyError:
util.warn(
"unique constraint key '%s' was not located in "
- "columns for table '%s'" % (c, table_name)
+ "columns for table '%s'" % (c, table.name)
)
else:
constrained_cols.append(constrained_col)
@@ -1114,29 +1984,166 @@ class Inspector(inspection.Inspectable["Inspector"]):
def _reflect_check_constraints(
self,
- table_name,
- schema,
- table,
- cols_by_orig_name,
- include_columns,
- exclude_columns,
- reflection_options,
- ):
- try:
- constraints = self.get_check_constraints(table_name, schema)
- except NotImplementedError:
- # optional dialect feature
- return
-
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ cols_by_orig_name: Dict[str, sa_schema.Column[Any]],
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str],
+ reflection_options: Dict[str, Any],
+ ) -> None:
+ constraints = _reflect_info.check_constraints.get(table_key, [])
for const_d in constraints:
table.append_constraint(sa_schema.CheckConstraint(**const_d))
def _reflect_table_comment(
- self, table_name, schema, table, reflection_options
- ):
- try:
- comment_dict = self.get_table_comment(table_name, schema)
- except NotImplementedError:
- return
+ self,
+ _reflect_info: _ReflectionInfo,
+ table_key: TableKey,
+ table: sa_schema.Table,
+ reflection_options: Dict[str, Any],
+ ) -> None:
+ comment_dict = _reflect_info.table_comment.get(table_key)
+ if comment_dict:
+ table.comment = comment_dict["text"]
+
+ def _get_reflection_info(
+ self,
+ schema: Optional[str] = None,
+ filter_names: Optional[Collection[str]] = None,
+ available: Optional[Collection[str]] = None,
+ _reflect_info: Optional[_ReflectionInfo] = None,
+ **kw: Any,
+ ) -> _ReflectionInfo:
+ kw["schema"] = schema
+
+ if filter_names and available and len(filter_names) > 100:
+ fraction = len(filter_names) / len(available)
+ else:
+ fraction = None
+
+ unreflectable: Dict[TableKey, exc.UnreflectableTableError]
+ kw["unreflectable"] = unreflectable = {}
+
+ has_result: bool = True
+
+ def run(
+ meth: Any,
+ *,
+ optional: bool = False,
+ check_filter_names_from_meth: bool = False,
+ ) -> Any:
+ nonlocal has_result
+ # simple heuristic to improve reflection performance if a
+ # dialect implements multi_reflection:
+ # if more than 50% of the tables in the db are in filter_names
+ # load all the tables, since it's most likely faster to avoid
+ # a filter on that many tables.
+ if (
+ fraction is None
+ or fraction <= 0.5
+ or not self.dialect._overrides_default(meth.__name__)
+ ):
+ _fn = filter_names
+ else:
+ _fn = None
+ try:
+ if has_result:
+ res = meth(filter_names=_fn, **kw)
+ if check_filter_names_from_meth and not res:
+ # method returned no result data.
+ # skip any future call methods
+ has_result = False
+ else:
+ res = {}
+ except NotImplementedError:
+ if not optional:
+ raise
+ res = {}
+ return res
+
+ info = _ReflectionInfo(
+ columns=run(
+ self.get_multi_columns, check_filter_names_from_meth=True
+ ),
+ pk_constraint=run(self.get_multi_pk_constraint),
+ foreign_keys=run(self.get_multi_foreign_keys),
+ indexes=run(self.get_multi_indexes),
+ unique_constraints=run(
+ self.get_multi_unique_constraints, optional=True
+ ),
+ table_comment=run(self.get_multi_table_comment, optional=True),
+ check_constraints=run(
+ self.get_multi_check_constraints, optional=True
+ ),
+ table_options=run(self.get_multi_table_options, optional=True),
+ unreflectable=unreflectable,
+ )
+ if _reflect_info:
+ _reflect_info.update(info)
+ return _reflect_info
else:
- table.comment = comment_dict.get("text", None)
+ return info
+
+
+@final
+class ReflectionDefaults:
+ """provides blank default values for reflection methods."""
+
+ @classmethod
+ def columns(cls) -> List[ReflectedColumn]:
+ return []
+
+ @classmethod
+ def pk_constraint(cls) -> ReflectedPrimaryKeyConstraint:
+ return { # type: ignore # pep-655 not supported
+ "name": None,
+ "constrained_columns": [],
+ }
+
+ @classmethod
+ def foreign_keys(cls) -> List[ReflectedForeignKeyConstraint]:
+ return []
+
+ @classmethod
+ def indexes(cls) -> List[ReflectedIndex]:
+ return []
+
+ @classmethod
+ def unique_constraints(cls) -> List[ReflectedUniqueConstraint]:
+ return []
+
+ @classmethod
+ def check_constraints(cls) -> List[ReflectedCheckConstraint]:
+ return []
+
+ @classmethod
+ def table_options(cls) -> Dict[str, Any]:
+ return {}
+
+ @classmethod
+ def table_comment(cls) -> ReflectedTableComment:
+ return {"text": None}
+
+
+@dataclass
+class _ReflectionInfo:
+ columns: Dict[TableKey, List[ReflectedColumn]]
+ pk_constraint: Dict[TableKey, Optional[ReflectedPrimaryKeyConstraint]]
+ foreign_keys: Dict[TableKey, List[ReflectedForeignKeyConstraint]]
+ indexes: Dict[TableKey, List[ReflectedIndex]]
+ # optionals
+ unique_constraints: Dict[TableKey, List[ReflectedUniqueConstraint]]
+ table_comment: Dict[TableKey, Optional[ReflectedTableComment]]
+ check_constraints: Dict[TableKey, List[ReflectedCheckConstraint]]
+ table_options: Dict[TableKey, Dict[str, Any]]
+ unreflectable: Dict[TableKey, exc.UnreflectableTableError]
+
+ def update(self, other: _ReflectionInfo) -> None:
+ for k, v in self.__dict__.items():
+ ov = getattr(other, k)
+ if ov is not None:
+ if v is None:
+ setattr(self, k, ov)
+ else:
+ v.update(ov)
diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py
index 391f74772..70c01d8d3 100644
--- a/lib/sqlalchemy/sql/base.py
+++ b/lib/sqlalchemy/sql/base.py
@@ -536,7 +536,7 @@ class DialectKWArgs:
util.portable_instancemethod(self._kw_reg_for_dialect_cls)
)
- def _validate_dialect_kwargs(self, kwargs: Any) -> None:
+ def _validate_dialect_kwargs(self, kwargs: Dict[str, Any]) -> None:
# validate remaining kwargs that they all specify DB prefixes
if not kwargs:
diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py
index c16fbdae1..5922c2db0 100644
--- a/lib/sqlalchemy/sql/cache_key.py
+++ b/lib/sqlalchemy/sql/cache_key.py
@@ -12,6 +12,7 @@ from itertools import zip_longest
import typing
from typing import Any
from typing import Dict
+from typing import Iterable
from typing import Iterator
from typing import List
from typing import MutableMapping
@@ -546,6 +547,43 @@ class CacheKey(NamedTuple):
return target_element.params(translate)
+def _ad_hoc_cache_key_from_args(
+ tokens: Tuple[Any, ...],
+ traverse_args: Iterable[Tuple[str, InternalTraversal]],
+ args: Iterable[Any],
+) -> Tuple[Any, ...]:
+ """a quick cache key generator used by reflection.flexi_cache."""
+ bindparams: List[BindParameter[Any]] = []
+
+ _anon_map = anon_map()
+
+ tup = tokens
+
+ for (attrname, sym), arg in zip(traverse_args, args):
+ key = sym.name
+ visit_key = key.replace("dp_", "visit_")
+
+ if arg is None:
+ tup += (attrname, None)
+ continue
+
+ meth = getattr(_cache_key_traversal_visitor, visit_key)
+ if meth is CACHE_IN_PLACE:
+ tup += (attrname, arg)
+ elif meth in (
+ CALL_GEN_CACHE_KEY,
+ STATIC_CACHE_KEY,
+ ANON_NAME,
+ PROPAGATE_ATTRS,
+ ):
+ raise NotImplementedError(
+ f"Haven't implemented symbol {meth} for ad-hoc key from args"
+ )
+ else:
+ tup += meth(attrname, arg, None, _anon_map, bindparams)
+ return tup
+
+
class _CacheKeyTraversal(HasTraversalDispatch):
# very common elements are inlined into the main _get_cache_key() method
# to produce a dramatic savings in Python function call overhead
diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py
index c37b60003..1c4b3b0ce 100644
--- a/lib/sqlalchemy/sql/schema.py
+++ b/lib/sqlalchemy/sql/schema.py
@@ -38,6 +38,7 @@ import typing
from typing import Any
from typing import Callable
from typing import cast
+from typing import Collection
from typing import Dict
from typing import Iterable
from typing import Iterator
@@ -99,6 +100,7 @@ if typing.TYPE_CHECKING:
from ..engine.interfaces import _ExecuteOptionsParameter
from ..engine.interfaces import ExecutionContext
from ..engine.mock import MockConnection
+ from ..engine.reflection import _ReflectionInfo
from ..sql.selectable import FromClause
_T = TypeVar("_T", bound="Any")
@@ -493,7 +495,7 @@ class Table(
keep_existing: bool = False,
extend_existing: bool = False,
resolve_fks: bool = True,
- include_columns: Optional[Iterable[str]] = None,
+ include_columns: Optional[Collection[str]] = None,
implicit_returning: bool = True,
comment: Optional[str] = None,
info: Optional[Dict[Any, Any]] = None,
@@ -829,6 +831,7 @@ class Table(
self.fullname = self.name
self.implicit_returning = implicit_returning
+ _reflect_info = kw.pop("_reflect_info", None)
self.comment = comment
@@ -852,6 +855,7 @@ class Table(
autoload_with,
include_columns,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
resolve_fks=resolve_fks,
)
@@ -869,10 +873,11 @@ class Table(
self,
metadata: MetaData,
autoload_with: Union[Engine, Connection],
- include_columns: Optional[Iterable[str]],
- exclude_columns: Iterable[str] = (),
+ include_columns: Optional[Collection[str]],
+ exclude_columns: Collection[str] = (),
resolve_fks: bool = True,
_extend_on: Optional[Set[Table]] = None,
+ _reflect_info: _ReflectionInfo | None = None,
) -> None:
insp = inspection.inspect(autoload_with)
with insp._inspection_context() as conn_insp:
@@ -882,6 +887,7 @@ class Table(
exclude_columns,
resolve_fks,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
)
@property
@@ -924,6 +930,7 @@ class Table(
autoload_replace = kwargs.pop("autoload_replace", True)
schema = kwargs.pop("schema", None)
_extend_on = kwargs.pop("_extend_on", None)
+ _reflect_info = kwargs.pop("_reflect_info", None)
# these arguments are only used with _init()
kwargs.pop("extend_existing", False)
kwargs.pop("keep_existing", False)
@@ -972,6 +979,7 @@ class Table(
exclude_columns,
resolve_fks,
_extend_on=_extend_on,
+ _reflect_info=_reflect_info,
)
self._extra_kwargs(**kwargs)
@@ -3165,7 +3173,7 @@ class IdentityOptions:
nominvalue: Optional[bool] = None,
nomaxvalue: Optional[bool] = None,
cycle: Optional[bool] = None,
- cache: Optional[bool] = None,
+ cache: Optional[int] = None,
order: Optional[bool] = None,
) -> None:
"""Construct a :class:`.IdentityOptions` object.
@@ -5130,6 +5138,7 @@ class MetaData(HasSchemaAttr):
sorted(self.tables.values(), key=lambda t: t.key) # type: ignore
)
+ @util.preload_module("sqlalchemy.engine.reflection")
def reflect(
self,
bind: Union[Engine, Connection],
@@ -5159,7 +5168,7 @@ class MetaData(HasSchemaAttr):
is used, if any.
:param views:
- If True, also reflect views.
+ If True, also reflect views (materialized and plain).
:param only:
Optional. Load only a sub-set of available named tables. May be
@@ -5225,7 +5234,7 @@ class MetaData(HasSchemaAttr):
"""
with inspection.inspect(bind)._inspection_context() as insp:
- reflect_opts = {
+ reflect_opts: Any = {
"autoload_with": insp,
"extend_existing": extend_existing,
"autoload_replace": autoload_replace,
@@ -5241,15 +5250,21 @@ class MetaData(HasSchemaAttr):
if schema is not None:
reflect_opts["schema"] = schema
+ kind = util.preloaded.engine_reflection.ObjectKind.TABLE
available: util.OrderedSet[str] = util.OrderedSet(
insp.get_table_names(schema)
)
if views:
+ kind = util.preloaded.engine_reflection.ObjectKind.ANY
available.update(insp.get_view_names(schema))
+ try:
+ available.update(insp.get_materialized_view_names(schema))
+ except NotImplementedError:
+ pass
if schema is not None:
available_w_schema: util.OrderedSet[str] = util.OrderedSet(
- ["%s.%s" % (schema, name) for name in available]
+ [f"{schema}.{name}" for name in available]
)
else:
available_w_schema = available
@@ -5282,6 +5297,17 @@ class MetaData(HasSchemaAttr):
for name in only
if extend_existing or name not in current
]
+ # pass the available tables so the inspector can
+ # choose to ignore the filter_names
+ _reflect_info = insp._get_reflection_info(
+ schema=schema,
+ filter_names=load,
+ available=available,
+ kind=kind,
+ scope=util.preloaded.engine_reflection.ObjectScope.ANY,
+ **dialect_kwargs,
+ )
+ reflect_opts["_reflect_info"] = _reflect_info
for name in load:
try:
@@ -5489,7 +5515,7 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem):
nominvalue: Optional[bool] = None,
nomaxvalue: Optional[bool] = None,
cycle: Optional[bool] = None,
- cache: Optional[bool] = None,
+ cache: Optional[int] = None,
order: Optional[bool] = None,
) -> None:
"""Construct a GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY DDL
diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py
index 9888d7c18..937706363 100644
--- a/lib/sqlalchemy/testing/assertions.py
+++ b/lib/sqlalchemy/testing/assertions.py
@@ -644,13 +644,21 @@ class AssertsCompiledSQL:
class ComparesTables:
- def assert_tables_equal(self, table, reflected_table, strict_types=False):
+ def assert_tables_equal(
+ self,
+ table,
+ reflected_table,
+ strict_types=False,
+ strict_constraints=True,
+ ):
assert len(table.c) == len(reflected_table.c)
for c, reflected_c in zip(table.c, reflected_table.c):
eq_(c.name, reflected_c.name)
assert reflected_c is reflected_table.c[c.name]
- eq_(c.primary_key, reflected_c.primary_key)
- eq_(c.nullable, reflected_c.nullable)
+
+ if strict_constraints:
+ eq_(c.primary_key, reflected_c.primary_key)
+ eq_(c.nullable, reflected_c.nullable)
if strict_types:
msg = "Type '%s' doesn't correspond to type '%s'"
@@ -664,18 +672,20 @@ class ComparesTables:
if isinstance(c.type, sqltypes.String):
eq_(c.type.length, reflected_c.type.length)
- eq_(
- {f.column.name for f in c.foreign_keys},
- {f.column.name for f in reflected_c.foreign_keys},
- )
+ if strict_constraints:
+ eq_(
+ {f.column.name for f in c.foreign_keys},
+ {f.column.name for f in reflected_c.foreign_keys},
+ )
if c.server_default:
assert isinstance(
reflected_c.server_default, schema.FetchedValue
)
- assert len(table.primary_key) == len(reflected_table.primary_key)
- for c in table.primary_key:
- assert reflected_table.primary_key.columns[c.name] is not None
+ if strict_constraints:
+ assert len(table.primary_key) == len(reflected_table.primary_key)
+ for c in table.primary_key:
+ assert reflected_table.primary_key.columns[c.name] is not None
def assert_types_base(self, c1, c2):
assert c1.type._compare_type_affinity(
diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py
index fa7d2ca19..cea07b305 100644
--- a/lib/sqlalchemy/testing/plugin/pytestplugin.py
+++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py
@@ -741,13 +741,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
fn._sa_parametrize.append((argnames, pytest_params))
return fn
else:
+ _fn_argnames = inspect.getfullargspec(fn).args[1:]
if argnames is None:
- _argnames = inspect.getfullargspec(fn).args[1:]
+ _argnames = _fn_argnames
else:
_argnames = re.split(r", *", argnames)
if has_exclusions:
- _argnames += ["_exclusions"]
+ existing_exl = sum(
+ 1 for n in _fn_argnames if n.startswith("_exclusions")
+ )
+ current_exclusion_name = f"_exclusions_{existing_exl}"
+ _argnames += [current_exclusion_name]
@_pytest_fn_decorator
def check_exclusions(fn, *args, **kw):
@@ -755,13 +760,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions):
if _exclusions:
exlu = exclusions.compound().add(*_exclusions)
fn = exlu(fn)
- return fn(*args[0:-1], **kw)
-
- def process_metadata(spec):
- spec.args.append("_exclusions")
+ return fn(*args[:-1], **kw)
fn = check_exclusions(
- fn, add_positional_parameters=("_exclusions",)
+ fn, add_positional_parameters=(current_exclusion_name,)
)
return pytest.mark.parametrize(_argnames, pytest_params)(fn)
diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py
index d38437732..498d92a77 100644
--- a/lib/sqlalchemy/testing/provision.py
+++ b/lib/sqlalchemy/testing/provision.py
@@ -230,7 +230,39 @@ def drop_all_schema_objects(cfg, eng):
drop_all_schema_objects_pre_tables(cfg, eng)
+ drop_views(cfg, eng)
+
+ if config.requirements.materialized_views.enabled:
+ drop_materialized_views(cfg, eng)
+
inspector = inspect(eng)
+
+ consider_schemas = (None,)
+ if config.requirements.schemas.enabled_for_config(cfg):
+ consider_schemas += (cfg.test_schema, cfg.test_schema_2)
+ util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas)
+
+ drop_all_schema_objects_post_tables(cfg, eng)
+
+ if config.requirements.sequences.enabled_for_config(cfg):
+ with eng.begin() as conn:
+ for seq in inspector.get_sequence_names():
+ conn.execute(ddl.DropSequence(schema.Sequence(seq)))
+ if config.requirements.schemas.enabled_for_config(cfg):
+ for schema_name in [cfg.test_schema, cfg.test_schema_2]:
+ for seq in inspector.get_sequence_names(
+ schema=schema_name
+ ):
+ conn.execute(
+ ddl.DropSequence(
+ schema.Sequence(seq, schema=schema_name)
+ )
+ )
+
+
+def drop_views(cfg, eng):
+ inspector = inspect(eng)
+
try:
view_names = inspector.get_view_names()
except NotImplementedError:
@@ -244,7 +276,7 @@ def drop_all_schema_objects(cfg, eng):
if config.requirements.schemas.enabled_for_config(cfg):
try:
- view_names = inspector.get_view_names(schema="test_schema")
+ view_names = inspector.get_view_names(schema=cfg.test_schema)
except NotImplementedError:
pass
else:
@@ -255,32 +287,30 @@ def drop_all_schema_objects(cfg, eng):
schema.Table(
vname,
schema.MetaData(),
- schema="test_schema",
+ schema=cfg.test_schema,
)
)
)
- util.drop_all_tables(eng, inspector)
- if config.requirements.schemas.enabled_for_config(cfg):
- util.drop_all_tables(eng, inspector, schema=cfg.test_schema)
- util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2)
- drop_all_schema_objects_post_tables(cfg, eng)
+def drop_materialized_views(cfg, eng):
+ inspector = inspect(eng)
- if config.requirements.sequences.enabled_for_config(cfg):
+ mview_names = inspector.get_materialized_view_names()
+
+ with eng.begin() as conn:
+ for vname in mview_names:
+ conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}")
+
+ if config.requirements.schemas.enabled_for_config(cfg):
+ mview_names = inspector.get_materialized_view_names(
+ schema=cfg.test_schema
+ )
with eng.begin() as conn:
- for seq in inspector.get_sequence_names():
- conn.execute(ddl.DropSequence(schema.Sequence(seq)))
- if config.requirements.schemas.enabled_for_config(cfg):
- for schema_name in [cfg.test_schema, cfg.test_schema_2]:
- for seq in inspector.get_sequence_names(
- schema=schema_name
- ):
- conn.execute(
- ddl.DropSequence(
- schema.Sequence(seq, schema=schema_name)
- )
- )
+ for vname in mview_names:
+ conn.exec_driver_sql(
+ f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}"
+ )
@register.init
diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py
index 4f9c73cf6..038f6e9bd 100644
--- a/lib/sqlalchemy/testing/requirements.py
+++ b/lib/sqlalchemy/testing/requirements.py
@@ -65,6 +65,25 @@ class SuiteRequirements(Requirements):
return exclusions.open()
@property
+ def foreign_keys_reflect_as_index(self):
+ """Target database creates an index that's reflected for
+ foreign keys."""
+
+ return exclusions.closed()
+
+ @property
+ def unique_index_reflect_as_unique_constraints(self):
+ """Target database reflects unique indexes as unique constrains."""
+
+ return exclusions.closed()
+
+ @property
+ def unique_constraints_reflect_as_index(self):
+ """Target database reflects unique constraints as indexes."""
+
+ return exclusions.closed()
+
+ @property
def table_value_constructor(self):
"""Database / dialect supports a query like::
@@ -629,6 +648,12 @@ class SuiteRequirements(Requirements):
return self.schemas
@property
+ def schema_create_delete(self):
+ """target database supports schema create and dropped with
+ 'CREATE SCHEMA' and 'DROP SCHEMA'"""
+ return exclusions.closed()
+
+ @property
def primary_key_constraint_reflection(self):
return exclusions.open()
@@ -693,6 +718,12 @@ class SuiteRequirements(Requirements):
return exclusions.open()
@property
+ def reflect_indexes_with_ascdesc(self):
+ """target database supports reflecting INDEX with per-column
+ ASC/DESC."""
+ return exclusions.open()
+
+ @property
def indexes_with_expressions(self):
"""target database supports CREATE INDEX against SQL expressions."""
return exclusions.closed()
@@ -1567,3 +1598,18 @@ class SuiteRequirements(Requirements):
def json_deserializer_binary(self):
"indicates if the json_deserializer function is called with bytes"
return exclusions.closed()
+
+ @property
+ def reflect_table_options(self):
+ """Target database must support reflecting table_options."""
+ return exclusions.closed()
+
+ @property
+ def materialized_views(self):
+ """Target database must support MATERIALIZED VIEWs."""
+ return exclusions.closed()
+
+ @property
+ def materialized_views_reflect_pk(self):
+ """Target database reflect MATERIALIZED VIEWs pks."""
+ return exclusions.closed()
diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py
index e4a92a732..46cbf4759 100644
--- a/lib/sqlalchemy/testing/schema.py
+++ b/lib/sqlalchemy/testing/schema.py
@@ -23,7 +23,7 @@ __all__ = ["Table", "Column"]
table_options = {}
-def Table(*args, **kw):
+def Table(*args, **kw) -> schema.Table:
"""A schema.Table wrapper/hook for dialect-specific tweaks."""
test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")}
@@ -134,6 +134,19 @@ class eq_type_affinity:
return self.target._type_affinity is not other._type_affinity
+class eq_compile_type:
+ """similar to eq_type_affinity but uses compile"""
+
+ def __init__(self, target):
+ self.target = target
+
+ def __eq__(self, other):
+ return self.target == other.compile()
+
+ def __ne__(self, other):
+ return self.target != other.compile()
+
+
class eq_clause_element:
"""Helper to compare SQL structures based on compare()"""
diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py
index b09b96227..7b8e2aa8b 100644
--- a/lib/sqlalchemy/testing/suite/test_reflection.py
+++ b/lib/sqlalchemy/testing/suite/test_reflection.py
@@ -7,6 +7,8 @@ import sqlalchemy as sa
from .. import config
from .. import engines
from .. import eq_
+from .. import expect_raises
+from .. import expect_raises_message
from .. import expect_warnings
from .. import fixtures
from .. import is_
@@ -24,12 +26,19 @@ from ... import MetaData
from ... import String
from ... import testing
from ... import types as sql_types
+from ...engine import Inspector
+from ...engine import ObjectKind
+from ...engine import ObjectScope
+from ...exc import NoSuchTableError
+from ...exc import UnreflectableTableError
from ...schema import DDL
from ...schema import Index
from ...sql.elements import quoted_name
from ...sql.schema import BLANK_SCHEMA
+from ...testing import ComparesTables
from ...testing import is_false
from ...testing import is_true
+from ...testing import mock
metadata, users = None, None
@@ -61,6 +70,19 @@ class HasTableTest(fixtures.TablesTest):
is_false(config.db.dialect.has_table(conn, "test_table_s"))
is_false(config.db.dialect.has_table(conn, "nonexistent_table"))
+ def test_has_table_cache(self, metadata):
+ insp = inspect(config.db)
+ is_true(insp.has_table("test_table"))
+ nt = Table("new_table", metadata, Column("col", Integer))
+ is_false(insp.has_table("new_table"))
+ nt.create(config.db)
+ try:
+ is_false(insp.has_table("new_table"))
+ insp.clear_cache()
+ is_true(insp.has_table("new_table"))
+ finally:
+ nt.drop(config.db)
+
@testing.requires.schemas
def test_has_table_schema(self):
with config.db.begin() as conn:
@@ -117,6 +139,7 @@ class HasIndexTest(fixtures.TablesTest):
metadata,
Column("id", Integer, primary_key=True),
Column("data", String(50)),
+ Column("data2", String(50)),
)
Index("my_idx", tt.c.data)
@@ -130,40 +153,56 @@ class HasIndexTest(fixtures.TablesTest):
)
Index("my_idx_s", tt.c.data)
- def test_has_index(self):
- with config.db.begin() as conn:
- assert config.db.dialect.has_index(conn, "test_table", "my_idx")
- assert not config.db.dialect.has_index(
- conn, "test_table", "my_idx_s"
- )
- assert not config.db.dialect.has_index(
- conn, "nonexistent_table", "my_idx"
- )
- assert not config.db.dialect.has_index(
- conn, "test_table", "nonexistent_idx"
- )
+ kind = testing.combinations("dialect", "inspector", argnames="kind")
+
+ def _has_index(self, kind, conn):
+ if kind == "dialect":
+ return lambda *a, **k: config.db.dialect.has_index(conn, *a, **k)
+ else:
+ return inspect(conn).has_index
+
+ @kind
+ def test_has_index(self, kind, connection, metadata):
+ meth = self._has_index(kind, connection)
+ assert meth("test_table", "my_idx")
+ assert not meth("test_table", "my_idx_s")
+ assert not meth("nonexistent_table", "my_idx")
+ assert not meth("test_table", "nonexistent_idx")
+
+ assert not meth("test_table", "my_idx_2")
+ assert not meth("test_table_2", "my_idx_3")
+ idx = Index("my_idx_2", self.tables.test_table.c.data2)
+ tbl = Table(
+ "test_table_2",
+ metadata,
+ Column("foo", Integer),
+ Index("my_idx_3", "foo"),
+ )
+ idx.create(connection)
+ tbl.create(connection)
+ try:
+ if kind == "inspector":
+ assert not meth("test_table", "my_idx_2")
+ assert not meth("test_table_2", "my_idx_3")
+ meth.__self__.clear_cache()
+ assert meth("test_table", "my_idx_2") is True
+ assert meth("test_table_2", "my_idx_3") is True
+ finally:
+ tbl.drop(connection)
+ idx.drop(connection)
@testing.requires.schemas
- def test_has_index_schema(self):
- with config.db.begin() as conn:
- assert config.db.dialect.has_index(
- conn, "test_table", "my_idx_s", schema=config.test_schema
- )
- assert not config.db.dialect.has_index(
- conn, "test_table", "my_idx", schema=config.test_schema
- )
- assert not config.db.dialect.has_index(
- conn,
- "nonexistent_table",
- "my_idx_s",
- schema=config.test_schema,
- )
- assert not config.db.dialect.has_index(
- conn,
- "test_table",
- "nonexistent_idx_s",
- schema=config.test_schema,
- )
+ @kind
+ def test_has_index_schema(self, kind, connection):
+ meth = self._has_index(kind, connection)
+ assert meth("test_table", "my_idx_s", schema=config.test_schema)
+ assert not meth("test_table", "my_idx", schema=config.test_schema)
+ assert not meth(
+ "nonexistent_table", "my_idx_s", schema=config.test_schema
+ )
+ assert not meth(
+ "test_table", "nonexistent_idx_s", schema=config.test_schema
+ )
class QuotedNameArgumentTest(fixtures.TablesTest):
@@ -264,7 +303,12 @@ class QuotedNameArgumentTest(fixtures.TablesTest):
def test_get_table_options(self, name):
insp = inspect(config.db)
- insp.get_table_options(name)
+ if testing.requires.reflect_table_options.enabled:
+ res = insp.get_table_options(name)
+ is_true(isinstance(res, dict))
+ else:
+ with expect_raises(NotImplementedError):
+ res = insp.get_table_options(name)
@quote_fixtures
@testing.requires.view_column_reflection
@@ -311,7 +355,37 @@ class QuotedNameArgumentTest(fixtures.TablesTest):
assert insp.get_check_constraints(name)
-class ComponentReflectionTest(fixtures.TablesTest):
+def _multi_combination(fn):
+ schema = testing.combinations(
+ None,
+ (
+ lambda: config.test_schema,
+ testing.requires.schemas,
+ ),
+ argnames="schema",
+ )
+ scope = testing.combinations(
+ ObjectScope.DEFAULT,
+ ObjectScope.TEMPORARY,
+ ObjectScope.ANY,
+ argnames="scope",
+ )
+ kind = testing.combinations(
+ ObjectKind.TABLE,
+ ObjectKind.VIEW,
+ ObjectKind.MATERIALIZED_VIEW,
+ ObjectKind.ANY,
+ ObjectKind.ANY_VIEW,
+ ObjectKind.TABLE | ObjectKind.VIEW,
+ ObjectKind.TABLE | ObjectKind.MATERIALIZED_VIEW,
+ argnames="kind",
+ )
+ filter_names = testing.combinations(True, False, argnames="use_filter")
+
+ return schema(scope(kind(filter_names(fn))))
+
+
+class ComponentReflectionTest(ComparesTables, fixtures.TablesTest):
run_inserts = run_deletes = None
__backend__ = True
@@ -354,6 +428,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
"%susers.user_id" % schema_prefix, name="user_id_fk"
),
),
+ sa.CheckConstraint("test2 > 0", name="test2_gt_zero"),
schema=schema,
test_needs_fk=True,
)
@@ -364,6 +439,8 @@ class ComponentReflectionTest(fixtures.TablesTest):
Column("user_id", sa.INT, primary_key=True),
Column("test1", sa.CHAR(5), nullable=False),
Column("test2", sa.Float(), nullable=False),
+ Column("parent_user_id", sa.Integer),
+ sa.CheckConstraint("test2 > 0", name="test2_gt_zero"),
schema=schema,
test_needs_fk=True,
)
@@ -375,9 +452,19 @@ class ComponentReflectionTest(fixtures.TablesTest):
Column(
"address_id",
sa.Integer,
- sa.ForeignKey("%semail_addresses.address_id" % schema_prefix),
+ sa.ForeignKey(
+ "%semail_addresses.address_id" % schema_prefix,
+ name="email_add_id_fg",
+ ),
+ ),
+ Column("data", sa.String(30), unique=True),
+ sa.CheckConstraint(
+ "address_id > 0 AND address_id < 1000",
+ name="address_id_gt_zero",
+ ),
+ sa.UniqueConstraint(
+ "address_id", "dingaling_id", name="zz_dingalings_multiple"
),
- Column("data", sa.String(30)),
schema=schema,
test_needs_fk=True,
)
@@ -388,7 +475,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
Column(
"remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id)
),
- Column("email_address", sa.String(20)),
+ Column("email_address", sa.String(20), index=True),
sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"),
schema=schema,
test_needs_fk=True,
@@ -406,6 +493,12 @@ class ComponentReflectionTest(fixtures.TablesTest):
schema=schema,
comment=r"""the test % ' " \ table comment""",
)
+ Table(
+ "no_constraints",
+ metadata,
+ Column("data", sa.String(20)),
+ schema=schema,
+ )
if testing.requires.cross_schema_fk_reflection.enabled:
if schema is None:
@@ -449,7 +542,10 @@ class ComponentReflectionTest(fixtures.TablesTest):
)
if testing.requires.index_reflection.enabled:
- cls.define_index(metadata, users)
+ Index("users_t_idx", users.c.test1, users.c.test2, unique=True)
+ Index(
+ "users_all_idx", users.c.user_id, users.c.test2, users.c.test1
+ )
if not schema:
# test_needs_fk is at the moment to force MySQL InnoDB
@@ -468,7 +564,10 @@ class ComponentReflectionTest(fixtures.TablesTest):
test_needs_fk=True,
)
- if testing.requires.indexes_with_ascdesc.enabled:
+ if (
+ testing.requires.indexes_with_ascdesc.enabled
+ and testing.requires.reflect_indexes_with_ascdesc.enabled
+ ):
Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc())
Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc())
@@ -478,11 +577,15 @@ class ComponentReflectionTest(fixtures.TablesTest):
cls.define_temp_tables(metadata)
@classmethod
+ def temp_table_name(cls):
+ return get_temp_table_name(
+ config, config.db, f"user_tmp_{config.ident}"
+ )
+
+ @classmethod
def define_temp_tables(cls, metadata):
kw = temp_table_keyword_args(config, config.db)
- table_name = get_temp_table_name(
- config, config.db, "user_tmp_%s" % config.ident
- )
+ table_name = cls.temp_table_name()
user_tmp = Table(
table_name,
metadata,
@@ -495,7 +598,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
# unique constraints created against temp tables in different
# databases.
# https://www.arbinada.com/en/node/1645
- sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident),
+ sa.UniqueConstraint("name", name=f"user_tmp_uq_{config.ident}"),
sa.Index("user_tmp_ix", "foo"),
**kw,
)
@@ -514,32 +617,635 @@ class ComponentReflectionTest(fixtures.TablesTest):
event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v"))
@classmethod
- def define_index(cls, metadata, users):
- Index("users_t_idx", users.c.test1, users.c.test2)
- Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1)
-
- @classmethod
def define_views(cls, metadata, schema):
- for table_name in ("users", "email_addresses"):
+ if testing.requires.materialized_views.enabled:
+ materialized = {"dingalings"}
+ else:
+ materialized = set()
+ for table_name in ("users", "email_addresses", "dingalings"):
fullname = table_name
if schema:
- fullname = "%s.%s" % (schema, table_name)
+ fullname = f"{schema}.{table_name}"
view_name = fullname + "_v"
- query = "CREATE VIEW %s AS SELECT * FROM %s" % (
- view_name,
- fullname,
+ prefix = "MATERIALIZED " if table_name in materialized else ""
+ query = (
+ f"CREATE {prefix}VIEW {view_name} AS SELECT * FROM {fullname}"
)
event.listen(metadata, "after_create", DDL(query))
+ if table_name in materialized:
+ index_name = "mat_index"
+ if schema and testing.against("oracle"):
+ index_name = f"{schema}.{index_name}"
+ idx = f"CREATE INDEX {index_name} ON {view_name}(data)"
+ event.listen(metadata, "after_create", DDL(idx))
event.listen(
- metadata, "before_drop", DDL("DROP VIEW %s" % view_name)
+ metadata, "before_drop", DDL(f"DROP {prefix}VIEW {view_name}")
+ )
+
+ def _resolve_kind(self, kind, tables, views, materialized):
+ res = {}
+ if ObjectKind.TABLE in kind:
+ res.update(tables)
+ if ObjectKind.VIEW in kind:
+ res.update(views)
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ res.update(materialized)
+ return res
+
+ def _resolve_views(self, views, materialized):
+ if not testing.requires.view_column_reflection.enabled:
+ materialized.clear()
+ views.clear()
+ elif not testing.requires.materialized_views.enabled:
+ views.update(materialized)
+ materialized.clear()
+
+ def _resolve_names(self, schema, scope, filter_names, values):
+ scope_filter = lambda _: True # noqa: E731
+ if scope is ObjectScope.DEFAULT:
+ scope_filter = lambda k: "tmp" not in k[1] # noqa: E731
+ if scope is ObjectScope.TEMPORARY:
+ scope_filter = lambda k: "tmp" in k[1] # noqa: E731
+
+ removed = {
+ None: {"remote_table", "remote_table_2"},
+ testing.config.test_schema: {
+ "local_table",
+ "noncol_idx_test_nopk",
+ "noncol_idx_test_pk",
+ "user_tmp_v",
+ self.temp_table_name(),
+ },
+ }
+ if not testing.requires.cross_schema_fk_reflection.enabled:
+ removed[None].add("local_table")
+ removed[testing.config.test_schema].update(
+ ["remote_table", "remote_table_2"]
+ )
+ if not testing.requires.index_reflection.enabled:
+ removed[None].update(
+ ["noncol_idx_test_nopk", "noncol_idx_test_pk"]
)
+ if (
+ not testing.requires.temp_table_reflection.enabled
+ or not testing.requires.temp_table_names.enabled
+ ):
+ removed[None].update(["user_tmp_v", self.temp_table_name()])
+ if not testing.requires.temporary_views.enabled:
+ removed[None].update(["user_tmp_v"])
+
+ res = {
+ k: v
+ for k, v in values.items()
+ if scope_filter(k)
+ and k[1] not in removed[schema]
+ and (not filter_names or k[1] in filter_names)
+ }
+ return res
+
+ def exp_options(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ materialized = {(schema, "dingalings_v"): mock.ANY}
+ views = {
+ (schema, "email_addresses_v"): mock.ANY,
+ (schema, "users_v"): mock.ANY,
+ (schema, "user_tmp_v"): mock.ANY,
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): mock.ANY,
+ (schema, "dingalings"): mock.ANY,
+ (schema, "email_addresses"): mock.ANY,
+ (schema, "comment_test"): mock.ANY,
+ (schema, "no_constraints"): mock.ANY,
+ (schema, "local_table"): mock.ANY,
+ (schema, "remote_table"): mock.ANY,
+ (schema, "remote_table_2"): mock.ANY,
+ (schema, "noncol_idx_test_nopk"): mock.ANY,
+ (schema, "noncol_idx_test_pk"): mock.ANY,
+ (schema, self.temp_table_name()): mock.ANY,
+ }
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ def exp_comments(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ empty = {"text": None}
+ materialized = {(schema, "dingalings_v"): empty}
+ views = {
+ (schema, "email_addresses_v"): empty,
+ (schema, "users_v"): empty,
+ (schema, "user_tmp_v"): empty,
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): empty,
+ (schema, "dingalings"): empty,
+ (schema, "email_addresses"): empty,
+ (schema, "comment_test"): {
+ "text": r"""the test % ' " \ table comment"""
+ },
+ (schema, "no_constraints"): empty,
+ (schema, "local_table"): empty,
+ (schema, "remote_table"): empty,
+ (schema, "remote_table_2"): empty,
+ (schema, "noncol_idx_test_nopk"): empty,
+ (schema, "noncol_idx_test_pk"): empty,
+ (schema, self.temp_table_name()): empty,
+ }
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ def exp_columns(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ def col(
+ name, auto=False, default=mock.ANY, comment=None, nullable=True
+ ):
+ res = {
+ "name": name,
+ "autoincrement": auto,
+ "type": mock.ANY,
+ "default": default,
+ "comment": comment,
+ "nullable": nullable,
+ }
+ if auto == "omit":
+ res.pop("autoincrement")
+ return res
+
+ def pk(name, **kw):
+ kw = {"auto": True, "default": mock.ANY, "nullable": False, **kw}
+ return col(name, **kw)
+
+ materialized = {
+ (schema, "dingalings_v"): [
+ col("dingaling_id", auto="omit", nullable=mock.ANY),
+ col("address_id"),
+ col("data"),
+ ]
+ }
+ views = {
+ (schema, "email_addresses_v"): [
+ col("address_id", auto="omit", nullable=mock.ANY),
+ col("remote_user_id"),
+ col("email_address"),
+ ],
+ (schema, "users_v"): [
+ col("user_id", auto="omit", nullable=mock.ANY),
+ col("test1", nullable=mock.ANY),
+ col("test2", nullable=mock.ANY),
+ col("parent_user_id"),
+ ],
+ (schema, "user_tmp_v"): [
+ col("id", auto="omit", nullable=mock.ANY),
+ col("name"),
+ col("foo"),
+ ],
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): [
+ pk("user_id"),
+ col("test1", nullable=False),
+ col("test2", nullable=False),
+ col("parent_user_id"),
+ ],
+ (schema, "dingalings"): [
+ pk("dingaling_id"),
+ col("address_id"),
+ col("data"),
+ ],
+ (schema, "email_addresses"): [
+ pk("address_id"),
+ col("remote_user_id"),
+ col("email_address"),
+ ],
+ (schema, "comment_test"): [
+ pk("id", comment="id comment"),
+ col("data", comment="data % comment"),
+ col(
+ "d2",
+ comment=r"""Comment types type speedily ' " \ '' Fun!""",
+ ),
+ ],
+ (schema, "no_constraints"): [col("data")],
+ (schema, "local_table"): [pk("id"), col("data"), col("remote_id")],
+ (schema, "remote_table"): [pk("id"), col("local_id"), col("data")],
+ (schema, "remote_table_2"): [pk("id"), col("data")],
+ (schema, "noncol_idx_test_nopk"): [col("q")],
+ (schema, "noncol_idx_test_pk"): [pk("id"), col("q")],
+ (schema, self.temp_table_name()): [
+ pk("id"),
+ col("name"),
+ col("foo"),
+ ],
+ }
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_column_keys(self):
+ return {"name", "type", "nullable", "default"}
+
+ def exp_pks(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ def pk(*cols, name=mock.ANY):
+ return {"constrained_columns": list(cols), "name": name}
+
+ empty = pk(name=None)
+ if testing.requires.materialized_views_reflect_pk.enabled:
+ materialized = {(schema, "dingalings_v"): pk("dingaling_id")}
+ else:
+ materialized = {(schema, "dingalings_v"): empty}
+ views = {
+ (schema, "email_addresses_v"): empty,
+ (schema, "users_v"): empty,
+ (schema, "user_tmp_v"): empty,
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): pk("user_id"),
+ (schema, "dingalings"): pk("dingaling_id"),
+ (schema, "email_addresses"): pk("address_id", name="email_ad_pk"),
+ (schema, "comment_test"): pk("id"),
+ (schema, "no_constraints"): empty,
+ (schema, "local_table"): pk("id"),
+ (schema, "remote_table"): pk("id"),
+ (schema, "remote_table_2"): pk("id"),
+ (schema, "noncol_idx_test_nopk"): empty,
+ (schema, "noncol_idx_test_pk"): pk("id"),
+ (schema, self.temp_table_name()): pk("id"),
+ }
+ if not testing.requires.reflects_pk_names.enabled:
+ for val in tables.values():
+ if val["name"] is not None:
+ val["name"] = mock.ANY
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_pk_keys(self):
+ return {"name", "constrained_columns"}
+
+ def exp_fks(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ class tt:
+ def __eq__(self, other):
+ return (
+ other is None
+ or config.db.dialect.default_schema_name == other
+ )
+
+ def fk(cols, ref_col, ref_table, ref_schema=schema, name=mock.ANY):
+ return {
+ "constrained_columns": cols,
+ "referred_columns": ref_col,
+ "name": name,
+ "options": mock.ANY,
+ "referred_schema": ref_schema
+ if ref_schema is not None
+ else tt(),
+ "referred_table": ref_table,
+ }
+
+ materialized = {(schema, "dingalings_v"): []}
+ views = {
+ (schema, "email_addresses_v"): [],
+ (schema, "users_v"): [],
+ (schema, "user_tmp_v"): [],
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): [
+ fk(["parent_user_id"], ["user_id"], "users", name="user_id_fk")
+ ],
+ (schema, "dingalings"): [
+ fk(
+ ["address_id"],
+ ["address_id"],
+ "email_addresses",
+ name="email_add_id_fg",
+ )
+ ],
+ (schema, "email_addresses"): [
+ fk(["remote_user_id"], ["user_id"], "users")
+ ],
+ (schema, "comment_test"): [],
+ (schema, "no_constraints"): [],
+ (schema, "local_table"): [
+ fk(
+ ["remote_id"],
+ ["id"],
+ "remote_table_2",
+ ref_schema=config.test_schema,
+ )
+ ],
+ (schema, "remote_table"): [
+ fk(["local_id"], ["id"], "local_table", ref_schema=None)
+ ],
+ (schema, "remote_table_2"): [],
+ (schema, "noncol_idx_test_nopk"): [],
+ (schema, "noncol_idx_test_pk"): [],
+ (schema, self.temp_table_name()): [],
+ }
+ if not testing.requires.self_referential_foreign_keys.enabled:
+ tables[(schema, "users")].clear()
+ if not testing.requires.named_constraints.enabled:
+ for vals in tables.values():
+ for val in vals:
+ if val["name"] is not mock.ANY:
+ val["name"] = mock.ANY
+
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_fk_keys(self):
+ return {
+ "name",
+ "constrained_columns",
+ "referred_schema",
+ "referred_table",
+ "referred_columns",
+ }
+
+ def exp_indexes(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ def idx(
+ *cols,
+ name,
+ unique=False,
+ column_sorting=None,
+ duplicates=False,
+ fk=False,
+ ):
+ fk_req = testing.requires.foreign_keys_reflect_as_index
+ dup_req = testing.requires.unique_constraints_reflect_as_index
+ if (fk and not fk_req.enabled) or (
+ duplicates and not dup_req.enabled
+ ):
+ return ()
+ res = {
+ "unique": unique,
+ "column_names": list(cols),
+ "name": name,
+ "dialect_options": mock.ANY,
+ "include_columns": [],
+ }
+ if column_sorting:
+ res["column_sorting"] = {"q": ("desc",)}
+ if duplicates:
+ res["duplicates_constraint"] = name
+ return [res]
+
+ materialized = {(schema, "dingalings_v"): []}
+ views = {
+ (schema, "email_addresses_v"): [],
+ (schema, "users_v"): [],
+ (schema, "user_tmp_v"): [],
+ }
+ self._resolve_views(views, materialized)
+ if materialized:
+ materialized[(schema, "dingalings_v")].extend(
+ idx("data", name="mat_index")
+ )
+ tables = {
+ (schema, "users"): [
+ *idx("parent_user_id", name="user_id_fk", fk=True),
+ *idx("user_id", "test2", "test1", name="users_all_idx"),
+ *idx("test1", "test2", name="users_t_idx", unique=True),
+ ],
+ (schema, "dingalings"): [
+ *idx("data", name=mock.ANY, unique=True, duplicates=True),
+ *idx(
+ "address_id",
+ "dingaling_id",
+ name="zz_dingalings_multiple",
+ unique=True,
+ duplicates=True,
+ ),
+ ],
+ (schema, "email_addresses"): [
+ *idx("email_address", name=mock.ANY),
+ *idx("remote_user_id", name=mock.ANY, fk=True),
+ ],
+ (schema, "comment_test"): [],
+ (schema, "no_constraints"): [],
+ (schema, "local_table"): [
+ *idx("remote_id", name=mock.ANY, fk=True)
+ ],
+ (schema, "remote_table"): [
+ *idx("local_id", name=mock.ANY, fk=True)
+ ],
+ (schema, "remote_table_2"): [],
+ (schema, "noncol_idx_test_nopk"): [
+ *idx(
+ "q",
+ name="noncol_idx_nopk",
+ column_sorting={"q": ("desc",)},
+ )
+ ],
+ (schema, "noncol_idx_test_pk"): [
+ *idx(
+ "q", name="noncol_idx_pk", column_sorting={"q": ("desc",)}
+ )
+ ],
+ (schema, self.temp_table_name()): [
+ *idx("foo", name="user_tmp_ix"),
+ *idx(
+ "name",
+ name=f"user_tmp_uq_{config.ident}",
+ duplicates=True,
+ unique=True,
+ ),
+ ],
+ }
+ if (
+ not testing.requires.indexes_with_ascdesc.enabled
+ or not testing.requires.reflect_indexes_with_ascdesc.enabled
+ ):
+ tables[(schema, "noncol_idx_test_nopk")].clear()
+ tables[(schema, "noncol_idx_test_pk")].clear()
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_index_keys(self):
+ return {"name", "column_names", "unique"}
+
+ def exp_ucs(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ all_=False,
+ ):
+ def uc(*cols, name, duplicates_index=None, is_index=False):
+ req = testing.requires.unique_index_reflect_as_unique_constraints
+ if is_index and not req.enabled:
+ return ()
+ res = {
+ "column_names": list(cols),
+ "name": name,
+ }
+ if duplicates_index:
+ res["duplicates_index"] = duplicates_index
+ return [res]
+
+ materialized = {(schema, "dingalings_v"): []}
+ views = {
+ (schema, "email_addresses_v"): [],
+ (schema, "users_v"): [],
+ (schema, "user_tmp_v"): [],
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): [
+ *uc(
+ "test1",
+ "test2",
+ name="users_t_idx",
+ duplicates_index="users_t_idx",
+ is_index=True,
+ )
+ ],
+ (schema, "dingalings"): [
+ *uc("data", name=mock.ANY, duplicates_index=mock.ANY),
+ *uc(
+ "address_id",
+ "dingaling_id",
+ name="zz_dingalings_multiple",
+ duplicates_index="zz_dingalings_multiple",
+ ),
+ ],
+ (schema, "email_addresses"): [],
+ (schema, "comment_test"): [],
+ (schema, "no_constraints"): [],
+ (schema, "local_table"): [],
+ (schema, "remote_table"): [],
+ (schema, "remote_table_2"): [],
+ (schema, "noncol_idx_test_nopk"): [],
+ (schema, "noncol_idx_test_pk"): [],
+ (schema, self.temp_table_name()): [
+ *uc("name", name=f"user_tmp_uq_{config.ident}")
+ ],
+ }
+ if all_:
+ return {**materialized, **views, **tables}
+ else:
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_unique_cst_keys(self):
+ return {"name", "column_names"}
+
+ def exp_ccs(
+ self,
+ schema=None,
+ scope=ObjectScope.ANY,
+ kind=ObjectKind.ANY,
+ filter_names=None,
+ ):
+ class tt(str):
+ def __eq__(self, other):
+ res = (
+ other.lower()
+ .replace("(", "")
+ .replace(")", "")
+ .replace("`", "")
+ )
+ return self in res
+
+ def cc(text, name):
+ return {"sqltext": tt(text), "name": name}
+
+ # print({1: "test2 > (0)::double precision"} == {1: tt("test2 > 0")})
+ # assert 0
+ materialized = {(schema, "dingalings_v"): []}
+ views = {
+ (schema, "email_addresses_v"): [],
+ (schema, "users_v"): [],
+ (schema, "user_tmp_v"): [],
+ }
+ self._resolve_views(views, materialized)
+ tables = {
+ (schema, "users"): [cc("test2 > 0", "test2_gt_zero")],
+ (schema, "dingalings"): [
+ cc(
+ "address_id > 0 and address_id < 1000",
+ name="address_id_gt_zero",
+ ),
+ ],
+ (schema, "email_addresses"): [],
+ (schema, "comment_test"): [],
+ (schema, "no_constraints"): [],
+ (schema, "local_table"): [],
+ (schema, "remote_table"): [],
+ (schema, "remote_table_2"): [],
+ (schema, "noncol_idx_test_nopk"): [],
+ (schema, "noncol_idx_test_pk"): [],
+ (schema, self.temp_table_name()): [],
+ }
+ res = self._resolve_kind(kind, tables, views, materialized)
+ res = self._resolve_names(schema, scope, filter_names, res)
+ return res
+
+ @property
+ def _required_cc_keys(self):
+ return {"name", "sqltext"}
@testing.requires.schema_reflection
- def test_get_schema_names(self):
- insp = inspect(self.bind)
+ def test_get_schema_names(self, connection):
+ insp = inspect(connection)
- self.assert_(testing.config.test_schema in insp.get_schema_names())
+ is_true(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_has_schema(self, connection):
+ insp = inspect(connection)
+
+ is_true(insp.has_schema(testing.config.test_schema))
+ is_false(insp.has_schema("sa_fake_schema_foo"))
@testing.requires.schema_reflection
def test_get_schema_names_w_translate_map(self, connection):
@@ -553,7 +1259,37 @@ class ComponentReflectionTest(fixtures.TablesTest):
)
insp = inspect(connection)
- self.assert_(testing.config.test_schema in insp.get_schema_names())
+ is_true(testing.config.test_schema in insp.get_schema_names())
+
+ @testing.requires.schema_reflection
+ def test_has_schema_w_translate_map(self, connection):
+ connection = connection.execution_options(
+ schema_translate_map={
+ "foo": "bar",
+ BLANK_SCHEMA: testing.config.test_schema,
+ }
+ )
+ insp = inspect(connection)
+
+ is_true(insp.has_schema(testing.config.test_schema))
+ is_false(insp.has_schema("sa_fake_schema_foo"))
+
+ @testing.requires.schema_reflection
+ @testing.requires.schema_create_delete
+ def test_schema_cache(self, connection):
+ insp = inspect(connection)
+
+ is_false("foo_bar" in insp.get_schema_names())
+ is_false(insp.has_schema("foo_bar"))
+ connection.execute(DDL("CREATE SCHEMA foo_bar"))
+ try:
+ is_false("foo_bar" in insp.get_schema_names())
+ is_false(insp.has_schema("foo_bar"))
+ insp.clear_cache()
+ is_true("foo_bar" in insp.get_schema_names())
+ is_true(insp.has_schema("foo_bar"))
+ finally:
+ connection.execute(DDL("DROP SCHEMA foo_bar"))
@testing.requires.schema_reflection
def test_dialect_initialize(self):
@@ -562,113 +1298,115 @@ class ComponentReflectionTest(fixtures.TablesTest):
assert hasattr(engine.dialect, "default_schema_name")
@testing.requires.schema_reflection
- def test_get_default_schema_name(self):
- insp = inspect(self.bind)
- eq_(insp.default_schema_name, self.bind.dialect.default_schema_name)
+ def test_get_default_schema_name(self, connection):
+ insp = inspect(connection)
+ eq_(insp.default_schema_name, connection.dialect.default_schema_name)
- @testing.requires.foreign_key_constraint_reflection
@testing.combinations(
- (None, True, False, False),
- (None, True, False, True, testing.requires.schemas),
- ("foreign_key", True, False, False),
- (None, False, True, False),
- (None, False, True, True, testing.requires.schemas),
- (None, True, True, False),
- (None, True, True, True, testing.requires.schemas),
- argnames="order_by,include_plain,include_views,use_schema",
+ None,
+ ("foreign_key", testing.requires.foreign_key_constraint_reflection),
+ argnames="order_by",
)
- def test_get_table_names(
- self, connection, order_by, include_plain, include_views, use_schema
- ):
+ @testing.combinations(
+ (True, testing.requires.schemas), False, argnames="use_schema"
+ )
+ def test_get_table_names(self, connection, order_by, use_schema):
if use_schema:
schema = config.test_schema
else:
schema = None
- _ignore_tables = [
+ _ignore_tables = {
"comment_test",
"noncol_idx_test_pk",
"noncol_idx_test_nopk",
"local_table",
"remote_table",
"remote_table_2",
- ]
+ "no_constraints",
+ }
insp = inspect(connection)
- if include_views:
- table_names = insp.get_view_names(schema)
- table_names.sort()
- answer = ["email_addresses_v", "users_v"]
- eq_(sorted(table_names), answer)
+ if order_by:
+ tables = [
+ rec[0]
+ for rec in insp.get_sorted_table_and_fkc_names(schema)
+ if rec[0]
+ ]
+ else:
+ tables = insp.get_table_names(schema)
+ table_names = [t for t in tables if t not in _ignore_tables]
- if include_plain:
- if order_by:
- tables = [
- rec[0]
- for rec in insp.get_sorted_table_and_fkc_names(schema)
- if rec[0]
- ]
- else:
- tables = insp.get_table_names(schema)
- table_names = [t for t in tables if t not in _ignore_tables]
+ if order_by == "foreign_key":
+ answer = ["users", "email_addresses", "dingalings"]
+ eq_(table_names, answer)
+ else:
+ answer = ["dingalings", "email_addresses", "users"]
+ eq_(sorted(table_names), answer)
- if order_by == "foreign_key":
- answer = ["users", "email_addresses", "dingalings"]
- eq_(table_names, answer)
- else:
- answer = ["dingalings", "email_addresses", "users"]
- eq_(sorted(table_names), answer)
+ @testing.combinations(
+ (True, testing.requires.schemas), False, argnames="use_schema"
+ )
+ def test_get_view_names(self, connection, use_schema):
+ insp = inspect(connection)
+ if use_schema:
+ schema = config.test_schema
+ else:
+ schema = None
+ table_names = insp.get_view_names(schema)
+ if testing.requires.materialized_views.enabled:
+ eq_(sorted(table_names), ["email_addresses_v", "users_v"])
+ eq_(insp.get_materialized_view_names(schema), ["dingalings_v"])
+ else:
+ answer = ["dingalings_v", "email_addresses_v", "users_v"]
+ eq_(sorted(table_names), answer)
@testing.requires.temp_table_names
- def test_get_temp_table_names(self):
- insp = inspect(self.bind)
+ def test_get_temp_table_names(self, connection):
+ insp = inspect(connection)
temp_table_names = insp.get_temp_table_names()
- eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident])
+ eq_(sorted(temp_table_names), [f"user_tmp_{config.ident}"])
@testing.requires.view_reflection
- @testing.requires.temp_table_names
@testing.requires.temporary_views
- def test_get_temp_view_names(self):
- insp = inspect(self.bind)
+ def test_get_temp_view_names(self, connection):
+ insp = inspect(connection)
temp_table_names = insp.get_temp_view_names()
eq_(sorted(temp_table_names), ["user_tmp_v"])
@testing.requires.comment_reflection
- def test_get_comments(self):
- self._test_get_comments()
+ def test_get_comments(self, connection):
+ self._test_get_comments(connection)
@testing.requires.comment_reflection
@testing.requires.schemas
- def test_get_comments_with_schema(self):
- self._test_get_comments(testing.config.test_schema)
-
- def _test_get_comments(self, schema=None):
- insp = inspect(self.bind)
+ def test_get_comments_with_schema(self, connection):
+ self._test_get_comments(connection, testing.config.test_schema)
+ def _test_get_comments(self, connection, schema=None):
+ insp = inspect(connection)
+ exp = self.exp_comments(schema=schema)
eq_(
insp.get_table_comment("comment_test", schema=schema),
- {"text": r"""the test % ' " \ table comment"""},
+ exp[(schema, "comment_test")],
)
- eq_(insp.get_table_comment("users", schema=schema), {"text": None})
+ eq_(
+ insp.get_table_comment("users", schema=schema),
+ exp[(schema, "users")],
+ )
eq_(
- [
- {"name": rec["name"], "comment": rec["comment"]}
- for rec in insp.get_columns("comment_test", schema=schema)
- ],
- [
- {"comment": "id comment", "name": "id"},
- {"comment": "data % comment", "name": "data"},
- {
- "comment": (
- r"""Comment types type speedily ' " \ '' Fun!"""
- ),
- "name": "d2",
- },
- ],
+ insp.get_table_comment("comment_test", schema=schema),
+ exp[(schema, "comment_test")],
+ )
+
+ no_cst = self.tables.no_constraints.name
+ eq_(
+ insp.get_table_comment(no_cst, schema=schema),
+ exp[(schema, no_cst)],
)
@testing.combinations(
@@ -691,7 +1429,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
users, addresses = (self.tables.users, self.tables.email_addresses)
if use_views:
- table_names = ["users_v", "email_addresses_v"]
+ table_names = ["users_v", "email_addresses_v", "dingalings_v"]
else:
table_names = ["users", "email_addresses"]
@@ -699,7 +1437,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
for table_name, table in zip(table_names, (users, addresses)):
schema_name = schema
cols = insp.get_columns(table_name, schema=schema_name)
- self.assert_(len(cols) > 0, len(cols))
+ is_true(len(cols) > 0, len(cols))
# should be in order
@@ -721,7 +1459,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
# assert that the desired type and return type share
# a base within one of the generic types.
- self.assert_(
+ is_true(
len(
set(ctype.__mro__)
.intersection(ctype_def.__mro__)
@@ -745,15 +1483,29 @@ class ComponentReflectionTest(fixtures.TablesTest):
if not col.primary_key:
assert cols[i]["default"] is None
+ # The case of a table with no column
+ # is tested below in TableNoColumnsTest
+
@testing.requires.temp_table_reflection
- def test_get_temp_table_columns(self):
- table_name = get_temp_table_name(
- config, self.bind, "user_tmp_%s" % config.ident
+ def test_reflect_table_temp_table(self, connection):
+
+ table_name = self.temp_table_name()
+ user_tmp = self.tables[table_name]
+
+ reflected_user_tmp = Table(
+ table_name, MetaData(), autoload_with=connection
)
+ self.assert_tables_equal(
+ user_tmp, reflected_user_tmp, strict_constraints=False
+ )
+
+ @testing.requires.temp_table_reflection
+ def test_get_temp_table_columns(self, connection):
+ table_name = self.temp_table_name()
user_tmp = self.tables[table_name]
- insp = inspect(self.bind)
+ insp = inspect(connection)
cols = insp.get_columns(table_name)
- self.assert_(len(cols) > 0, len(cols))
+ is_true(len(cols) > 0, len(cols))
for i, col in enumerate(user_tmp.columns):
eq_(col.name, cols[i]["name"])
@@ -761,8 +1513,8 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.temp_table_reflection
@testing.requires.view_column_reflection
@testing.requires.temporary_views
- def test_get_temp_view_columns(self):
- insp = inspect(self.bind)
+ def test_get_temp_view_columns(self, connection):
+ insp = inspect(connection)
cols = insp.get_columns("user_tmp_v")
eq_([col["name"] for col in cols], ["id", "name", "foo"])
@@ -778,18 +1530,27 @@ class ComponentReflectionTest(fixtures.TablesTest):
users, addresses = self.tables.users, self.tables.email_addresses
insp = inspect(connection)
+ exp = self.exp_pks(schema=schema)
users_cons = insp.get_pk_constraint(users.name, schema=schema)
- users_pkeys = users_cons["constrained_columns"]
- eq_(users_pkeys, ["user_id"])
+ self._check_list(
+ [users_cons], [exp[(schema, users.name)]], self._required_pk_keys
+ )
addr_cons = insp.get_pk_constraint(addresses.name, schema=schema)
- addr_pkeys = addr_cons["constrained_columns"]
- eq_(addr_pkeys, ["address_id"])
+ exp_cols = exp[(schema, addresses.name)]["constrained_columns"]
+ eq_(addr_cons["constrained_columns"], exp_cols)
with testing.requires.reflects_pk_names.fail_if():
eq_(addr_cons["name"], "email_ad_pk")
+ no_cst = self.tables.no_constraints.name
+ self._check_list(
+ [insp.get_pk_constraint(no_cst, schema=schema)],
+ [exp[(schema, no_cst)]],
+ self._required_pk_keys,
+ )
+
@testing.combinations(
(False,), (True, testing.requires.schemas), argnames="use_schema"
)
@@ -815,31 +1576,33 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(fkey1["referred_schema"], expected_schema)
eq_(fkey1["referred_table"], users.name)
eq_(fkey1["referred_columns"], ["user_id"])
- if testing.requires.self_referential_foreign_keys.enabled:
- eq_(fkey1["constrained_columns"], ["parent_user_id"])
+ eq_(fkey1["constrained_columns"], ["parent_user_id"])
# addresses
addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema)
fkey1 = addr_fkeys[0]
with testing.requires.implicitly_named_constraints.fail_if():
- self.assert_(fkey1["name"] is not None)
+ is_true(fkey1["name"] is not None)
eq_(fkey1["referred_schema"], expected_schema)
eq_(fkey1["referred_table"], users.name)
eq_(fkey1["referred_columns"], ["user_id"])
eq_(fkey1["constrained_columns"], ["remote_user_id"])
+ no_cst = self.tables.no_constraints.name
+ eq_(insp.get_foreign_keys(no_cst, schema=schema), [])
+
@testing.requires.cross_schema_fk_reflection
@testing.requires.schemas
- def test_get_inter_schema_foreign_keys(self):
+ def test_get_inter_schema_foreign_keys(self, connection):
local_table, remote_table, remote_table_2 = self.tables(
- "%s.local_table" % self.bind.dialect.default_schema_name,
+ "%s.local_table" % connection.dialect.default_schema_name,
"%s.remote_table" % testing.config.test_schema,
"%s.remote_table_2" % testing.config.test_schema,
)
- insp = inspect(self.bind)
+ insp = inspect(connection)
local_fkeys = insp.get_foreign_keys(local_table.name)
eq_(len(local_fkeys), 1)
@@ -857,25 +1620,21 @@ class ComponentReflectionTest(fixtures.TablesTest):
fkey2 = remote_fkeys[0]
- assert fkey2["referred_schema"] in (
- None,
- self.bind.dialect.default_schema_name,
+ is_true(
+ fkey2["referred_schema"]
+ in (
+ None,
+ connection.dialect.default_schema_name,
+ )
)
eq_(fkey2["referred_table"], local_table.name)
eq_(fkey2["referred_columns"], ["id"])
eq_(fkey2["constrained_columns"], ["local_id"])
- def _assert_insp_indexes(self, indexes, expected_indexes):
- index_names = [d["name"] for d in indexes]
- for e_index in expected_indexes:
- assert e_index["name"] in index_names
- index = indexes[index_names.index(e_index["name"])]
- for key in e_index:
- eq_(e_index[key], index[key])
-
@testing.combinations(
(False,), (True, testing.requires.schemas), argnames="use_schema"
)
+ @testing.requires.index_reflection
def test_get_indexes(self, connection, use_schema):
if use_schema:
@@ -885,21 +1644,19 @@ class ComponentReflectionTest(fixtures.TablesTest):
# The database may decide to create indexes for foreign keys, etc.
# so there may be more indexes than expected.
- insp = inspect(self.bind)
+ insp = inspect(connection)
indexes = insp.get_indexes("users", schema=schema)
- expected_indexes = [
- {
- "unique": False,
- "column_names": ["test1", "test2"],
- "name": "users_t_idx",
- },
- {
- "unique": False,
- "column_names": ["user_id", "test2", "test1"],
- "name": "users_all_idx",
- },
- ]
- self._assert_insp_indexes(indexes, expected_indexes)
+ exp = self.exp_indexes(schema=schema)
+ self._check_list(
+ indexes, exp[(schema, "users")], self._required_index_keys
+ )
+
+ no_cst = self.tables.no_constraints.name
+ self._check_list(
+ insp.get_indexes(no_cst, schema=schema),
+ exp[(schema, no_cst)],
+ self._required_index_keys,
+ )
@testing.combinations(
("noncol_idx_test_nopk", "noncol_idx_nopk"),
@@ -908,15 +1665,15 @@ class ComponentReflectionTest(fixtures.TablesTest):
)
@testing.requires.index_reflection
@testing.requires.indexes_with_ascdesc
+ @testing.requires.reflect_indexes_with_ascdesc
def test_get_noncol_index(self, connection, tname, ixname):
insp = inspect(connection)
indexes = insp.get_indexes(tname)
-
# reflecting an index that has "x DESC" in it as the column.
# the DB may or may not give us "x", but make sure we get the index
# back, it has a name, it's connected to the table.
- expected_indexes = [{"unique": False, "name": ixname}]
- self._assert_insp_indexes(indexes, expected_indexes)
+ expected_indexes = self.exp_indexes()[(None, tname)]
+ self._check_list(indexes, expected_indexes, self._required_index_keys)
t = Table(tname, MetaData(), autoload_with=connection)
eq_(len(t.indexes), 1)
@@ -925,29 +1682,17 @@ class ComponentReflectionTest(fixtures.TablesTest):
@testing.requires.temp_table_reflection
@testing.requires.unique_constraint_reflection
- def test_get_temp_table_unique_constraints(self):
- insp = inspect(self.bind)
- reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident)
- for refl in reflected:
- # Different dialects handle duplicate index and constraints
- # differently, so ignore this flag
- refl.pop("duplicates_index", None)
- eq_(
- reflected,
- [
- {
- "column_names": ["name"],
- "name": "user_tmp_uq_%s" % config.ident,
- }
- ],
- )
+ def test_get_temp_table_unique_constraints(self, connection):
+ insp = inspect(connection)
+ name = self.temp_table_name()
+ reflected = insp.get_unique_constraints(name)
+ exp = self.exp_ucs(all_=True)[(None, name)]
+ self._check_list(reflected, exp, self._required_index_keys)
@testing.requires.temp_table_reflect_indexes
- def test_get_temp_table_indexes(self):
- insp = inspect(self.bind)
- table_name = get_temp_table_name(
- config, config.db, "user_tmp_%s" % config.ident
- )
+ def test_get_temp_table_indexes(self, connection):
+ insp = inspect(connection)
+ table_name = self.temp_table_name()
indexes = insp.get_indexes(table_name)
for ind in indexes:
ind.pop("dialect_options", None)
@@ -1005,9 +1750,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
)
table.create(connection)
- inspector = inspect(connection)
+ insp = inspect(connection)
reflected = sorted(
- inspector.get_unique_constraints("testtbl", schema=schema),
+ insp.get_unique_constraints("testtbl", schema=schema),
key=operator.itemgetter("name"),
)
@@ -1047,6 +1792,9 @@ class ComponentReflectionTest(fixtures.TablesTest):
eq_(names_that_duplicate_index, idx_names)
eq_(uq_names, set())
+ no_cst = self.tables.no_constraints.name
+ eq_(insp.get_unique_constraints(no_cst, schema=schema), [])
+
@testing.requires.view_reflection
@testing.combinations(
(False,), (True, testing.requires.schemas), argnames="use_schema"
@@ -1056,32 +1804,21 @@ class ComponentReflectionTest(fixtures.TablesTest):
schema = config.test_schema
else:
schema = None
- view_name1 = "users_v"
- view_name2 = "email_addresses_v"
insp = inspect(connection)
- v1 = insp.get_view_definition(view_name1, schema=schema)
- self.assert_(v1)
- v2 = insp.get_view_definition(view_name2, schema=schema)
- self.assert_(v2)
+ for view in ["users_v", "email_addresses_v", "dingalings_v"]:
+ v = insp.get_view_definition(view, schema=schema)
+ is_true(bool(v))
- # why is this here if it's PG specific ?
- @testing.combinations(
- ("users", False),
- ("users", True, testing.requires.schemas),
- argnames="table_name,use_schema",
- )
- @testing.only_on("postgresql", "PG specific feature")
- def test_get_table_oid(self, connection, table_name, use_schema):
- if use_schema:
- schema = config.test_schema
- else:
- schema = None
+ @testing.requires.view_reflection
+ def test_get_view_definition_does_not_exist(self, connection):
insp = inspect(connection)
- oid = insp.get_table_oid(table_name, schema)
- self.assert_(isinstance(oid, int))
+ with expect_raises(NoSuchTableError):
+ insp.get_view_definition("view_does_not_exist")
+ with expect_raises(NoSuchTableError):
+ insp.get_view_definition("users") # a table
@testing.requires.table_reflection
- def test_autoincrement_col(self):
+ def test_autoincrement_col(self, connection):
"""test that 'autoincrement' is reflected according to sqla's policy.
Don't mark this test as unsupported for any backend !
@@ -1094,7 +1831,7 @@ class ComponentReflectionTest(fixtures.TablesTest):
"""
- insp = inspect(self.bind)
+ insp = inspect(connection)
for tname, cname in [
("users", "user_id"),
@@ -1105,6 +1842,330 @@ class ComponentReflectionTest(fixtures.TablesTest):
id_ = {c["name"]: c for c in cols}[cname]
assert id_.get("autoincrement", True)
+ @testing.combinations(
+ (True, testing.requires.schemas), (False,), argnames="use_schema"
+ )
+ def test_get_table_options(self, use_schema):
+ insp = inspect(config.db)
+ schema = config.test_schema if use_schema else None
+
+ if testing.requires.reflect_table_options.enabled:
+ res = insp.get_table_options("users", schema=schema)
+ is_true(isinstance(res, dict))
+ # NOTE: can't really create a table with no option
+ res = insp.get_table_options("no_constraints", schema=schema)
+ is_true(isinstance(res, dict))
+ else:
+ with expect_raises(NotImplementedError):
+ res = insp.get_table_options("users", schema=schema)
+
+ @testing.combinations((True, testing.requires.schemas), False)
+ def test_multi_get_table_options(self, use_schema):
+ insp = inspect(config.db)
+ if testing.requires.reflect_table_options.enabled:
+ schema = config.test_schema if use_schema else None
+ res = insp.get_multi_table_options(schema=schema)
+
+ exp = {
+ (schema, table): insp.get_table_options(table, schema=schema)
+ for table in insp.get_table_names(schema=schema)
+ }
+ eq_(res, exp)
+ else:
+ with expect_raises(NotImplementedError):
+ res = insp.get_multi_table_options()
+
+ @testing.fixture
+ def get_multi_exp(self, connection):
+ def provide_fixture(
+ schema, scope, kind, use_filter, single_reflect_fn, exp_method
+ ):
+ insp = inspect(connection)
+ # call the reflection function at least once to avoid
+ # "Unexpected success" errors if the result is actually empty
+ # and NotImplementedError is not raised
+ single_reflect_fn(insp, "email_addresses")
+ kw = {"scope": scope, "kind": kind}
+ if schema:
+ schema = schema()
+
+ filter_names = []
+
+ if ObjectKind.TABLE in kind:
+ filter_names.extend(
+ ["comment_test", "users", "does-not-exist"]
+ )
+ if ObjectKind.VIEW in kind:
+ filter_names.extend(["email_addresses_v", "does-not-exist"])
+ if ObjectKind.MATERIALIZED_VIEW in kind:
+ filter_names.extend(["dingalings_v", "does-not-exist"])
+
+ if schema:
+ kw["schema"] = schema
+ if use_filter:
+ kw["filter_names"] = filter_names
+
+ exp = exp_method(
+ schema=schema,
+ scope=scope,
+ kind=kind,
+ filter_names=kw.get("filter_names"),
+ )
+ kws = [kw]
+ if scope == ObjectScope.DEFAULT:
+ nkw = kw.copy()
+ nkw.pop("scope")
+ kws.append(nkw)
+ if kind == ObjectKind.TABLE:
+ nkw = kw.copy()
+ nkw.pop("kind")
+ kws.append(nkw)
+
+ return inspect(connection), kws, exp
+
+ return provide_fixture
+
+ @testing.requires.reflect_table_options
+ @_multi_combination
+ def test_multi_get_table_options_tables(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_table_options,
+ self.exp_options,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_table_options(**kw)
+ eq_(result, exp)
+
+ @testing.requires.comment_reflection
+ @_multi_combination
+ def test_get_multi_table_comment(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_table_comment,
+ self.exp_comments,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ eq_(insp.get_multi_table_comment(**kw), exp)
+
+ def _check_list(self, result, exp, req_keys=None, msg=None):
+ if req_keys is None:
+ eq_(result, exp, msg)
+ else:
+ eq_(len(result), len(exp), msg)
+ for r, e in zip(result, exp):
+ for k in set(r) | set(e):
+ if k in req_keys or (k in r and k in e):
+ eq_(r[k], e[k], f"{msg} - {k} - {r}")
+
+ def _check_table_dict(self, result, exp, req_keys=None, make_lists=False):
+ eq_(set(result.keys()), set(exp.keys()))
+ for k in result:
+ r, e = result[k], exp[k]
+ if make_lists:
+ r, e = [r], [e]
+ self._check_list(r, e, req_keys, k)
+
+ @_multi_combination
+ def test_get_multi_columns(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_columns,
+ self.exp_columns,
+ )
+
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_columns(**kw)
+ self._check_table_dict(result, exp, self._required_column_keys)
+
+ @testing.requires.primary_key_constraint_reflection
+ @_multi_combination
+ def test_get_multi_pk_constraint(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_pk_constraint,
+ self.exp_pks,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_pk_constraint(**kw)
+ self._check_table_dict(
+ result, exp, self._required_pk_keys, make_lists=True
+ )
+
+ def _adjust_sort(self, result, expected, key):
+ if not testing.requires.implicitly_named_constraints.enabled:
+ for obj in [result, expected]:
+ for val in obj.values():
+ if len(val) > 1 and any(
+ v.get("name") in (None, mock.ANY) for v in val
+ ):
+ val.sort(key=key)
+
+ @testing.requires.foreign_key_constraint_reflection
+ @_multi_combination
+ def test_get_multi_foreign_keys(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_foreign_keys,
+ self.exp_fks,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_foreign_keys(**kw)
+ self._adjust_sort(
+ result, exp, lambda d: tuple(d["constrained_columns"])
+ )
+ self._check_table_dict(result, exp, self._required_fk_keys)
+
+ @testing.requires.index_reflection
+ @_multi_combination
+ def test_get_multi_indexes(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_indexes,
+ self.exp_indexes,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_indexes(**kw)
+ self._check_table_dict(result, exp, self._required_index_keys)
+
+ @testing.requires.unique_constraint_reflection
+ @_multi_combination
+ def test_get_multi_unique_constraints(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_unique_constraints,
+ self.exp_ucs,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_unique_constraints(**kw)
+ self._adjust_sort(result, exp, lambda d: tuple(d["column_names"]))
+ self._check_table_dict(result, exp, self._required_unique_cst_keys)
+
+ @testing.requires.check_constraint_reflection
+ @_multi_combination
+ def test_get_multi_check_constraints(
+ self, get_multi_exp, schema, scope, kind, use_filter
+ ):
+ insp, kws, exp = get_multi_exp(
+ schema,
+ scope,
+ kind,
+ use_filter,
+ Inspector.get_check_constraints,
+ self.exp_ccs,
+ )
+ for kw in kws:
+ insp.clear_cache()
+ result = insp.get_multi_check_constraints(**kw)
+ self._adjust_sort(result, exp, lambda d: tuple(d["sqltext"]))
+ self._check_table_dict(result, exp, self._required_cc_keys)
+
+ @testing.combinations(
+ ("get_table_options", testing.requires.reflect_table_options),
+ "get_columns",
+ (
+ "get_pk_constraint",
+ testing.requires.primary_key_constraint_reflection,
+ ),
+ (
+ "get_foreign_keys",
+ testing.requires.foreign_key_constraint_reflection,
+ ),
+ ("get_indexes", testing.requires.index_reflection),
+ (
+ "get_unique_constraints",
+ testing.requires.unique_constraint_reflection,
+ ),
+ (
+ "get_check_constraints",
+ testing.requires.check_constraint_reflection,
+ ),
+ ("get_table_comment", testing.requires.comment_reflection),
+ argnames="method",
+ )
+ def test_not_existing_table(self, method, connection):
+ insp = inspect(connection)
+ meth = getattr(insp, method)
+ with expect_raises(NoSuchTableError):
+ meth("table_does_not_exists")
+
+ def test_unreflectable(self, connection):
+ mc = Inspector.get_multi_columns
+
+ def patched(*a, **k):
+ ur = k.setdefault("unreflectable", {})
+ ur[(None, "some_table")] = UnreflectableTableError("err")
+ return mc(*a, **k)
+
+ with mock.patch.object(Inspector, "get_multi_columns", patched):
+ with expect_raises_message(UnreflectableTableError, "err"):
+ inspect(connection).reflect_table(
+ Table("some_table", MetaData()), None
+ )
+
+ @testing.combinations(True, False, argnames="use_schema")
+ @testing.combinations(
+ (True, testing.requires.views), False, argnames="views"
+ )
+ def test_metadata(self, connection, use_schema, views):
+ m = MetaData()
+ schema = config.test_schema if use_schema else None
+ m.reflect(connection, schema=schema, views=views, resolve_fks=False)
+
+ insp = inspect(connection)
+ tables = insp.get_table_names(schema)
+ if views:
+ tables += insp.get_view_names(schema)
+ try:
+ tables += insp.get_materialized_view_names(schema)
+ except NotImplementedError:
+ pass
+ if schema:
+ tables = [f"{schema}.{t}" for t in tables]
+ eq_(sorted(m.tables), sorted(tables))
+
class TableNoColumnsTest(fixtures.TestBase):
__requires__ = ("reflect_tables_no_columns",)
@@ -1118,9 +2179,6 @@ class TableNoColumnsTest(fixtures.TestBase):
@testing.fixture
def view_no_columns(self, connection, metadata):
Table("empty", metadata)
- metadata.create_all(connection)
-
- Table("empty", metadata)
event.listen(
metadata,
"after_create",
@@ -1134,31 +2192,32 @@ class TableNoColumnsTest(fixtures.TestBase):
)
metadata.create_all(connection)
- @testing.requires.reflect_tables_no_columns
def test_reflect_table_no_columns(self, connection, table_no_columns):
t2 = Table("empty", MetaData(), autoload_with=connection)
eq_(list(t2.c), [])
- @testing.requires.reflect_tables_no_columns
def test_get_columns_table_no_columns(self, connection, table_no_columns):
- eq_(inspect(connection).get_columns("empty"), [])
+ insp = inspect(connection)
+ eq_(insp.get_columns("empty"), [])
+ multi = insp.get_multi_columns()
+ eq_(multi, {(None, "empty"): []})
- @testing.requires.reflect_tables_no_columns
def test_reflect_incl_table_no_columns(self, connection, table_no_columns):
m = MetaData()
m.reflect(connection)
assert set(m.tables).intersection(["empty"])
@testing.requires.views
- @testing.requires.reflect_tables_no_columns
def test_reflect_view_no_columns(self, connection, view_no_columns):
t2 = Table("empty_v", MetaData(), autoload_with=connection)
eq_(list(t2.c), [])
@testing.requires.views
- @testing.requires.reflect_tables_no_columns
def test_get_columns_view_no_columns(self, connection, view_no_columns):
- eq_(inspect(connection).get_columns("empty_v"), [])
+ insp = inspect(connection)
+ eq_(insp.get_columns("empty_v"), [])
+ multi = insp.get_multi_columns(kind=ObjectKind.VIEW)
+ eq_(multi, {(None, "empty_v"): []})
class ComponentReflectionTestExtra(fixtures.TestBase):
@@ -1185,12 +2244,18 @@ class ComponentReflectionTestExtra(fixtures.TestBase):
),
schema=schema,
)
+ Table(
+ "no_constraints",
+ metadata,
+ Column("data", sa.String(20)),
+ schema=schema,
+ )
metadata.create_all(connection)
- inspector = inspect(connection)
+ insp = inspect(connection)
reflected = sorted(
- inspector.get_check_constraints("sa_cc", schema=schema),
+ insp.get_check_constraints("sa_cc", schema=schema),
key=operator.itemgetter("name"),
)
@@ -1213,6 +2278,8 @@ class ComponentReflectionTestExtra(fixtures.TestBase):
{"name": "cc1", "sqltext": "a > 1 and a < 5"},
],
)
+ no_cst = "no_constraints"
+ eq_(insp.get_check_constraints(no_cst, schema=schema), [])
@testing.requires.indexes_with_expressions
def test_reflect_expression_based_indexes(self, metadata, connection):
@@ -1642,7 +2709,8 @@ class IdentityReflectionTest(fixtures.TablesTest):
if col["name"] == "normal":
is_false("identity" in col)
elif col["name"] == "id1":
- is_true(col["autoincrement"] in (True, "auto"))
+ if "autoincrement" in col:
+ is_true(col["autoincrement"])
eq_(col["default"], None)
is_true("identity" in col)
self.check(
@@ -1659,7 +2727,8 @@ class IdentityReflectionTest(fixtures.TablesTest):
approx=True,
)
elif col["name"] == "id2":
- is_true(col["autoincrement"] in (True, "auto"))
+ if "autoincrement" in col:
+ is_true(col["autoincrement"])
eq_(col["default"], None)
is_true("identity" in col)
self.check(
@@ -1685,7 +2754,8 @@ class IdentityReflectionTest(fixtures.TablesTest):
if col["name"] == "normal":
is_false("identity" in col)
elif col["name"] == "id1":
- is_true(col["autoincrement"] in (True, "auto"))
+ if "autoincrement" in col:
+ is_true(col["autoincrement"])
eq_(col["default"], None)
is_true("identity" in col)
self.check(
@@ -1735,16 +2805,16 @@ class CompositeKeyReflectionTest(fixtures.TablesTest):
)
@testing.requires.primary_key_constraint_reflection
- def test_pk_column_order(self):
+ def test_pk_column_order(self, connection):
# test for issue #5661
- insp = inspect(self.bind)
+ insp = inspect(connection)
primary_key = insp.get_pk_constraint(self.tables.tb1.name)
eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"])
@testing.requires.foreign_key_constraint_reflection
- def test_fk_column_order(self):
+ def test_fk_column_order(self, connection):
# test for issue #5661
- insp = inspect(self.bind)
+ insp = inspect(connection)
foreign_keys = insp.get_foreign_keys(self.tables.tb2.name)
eq_(len(foreign_keys), 1)
fkey1 = foreign_keys[0]
diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py
index eae051992..e15fad642 100644
--- a/lib/sqlalchemy/testing/suite/test_sequence.py
+++ b/lib/sqlalchemy/testing/suite/test_sequence.py
@@ -194,16 +194,23 @@ class HasSequenceTest(fixtures.TablesTest):
)
def test_has_sequence(self, connection):
- eq_(
- inspect(connection).has_sequence("user_id_seq"),
- True,
- )
+ eq_(inspect(connection).has_sequence("user_id_seq"), True)
+
+ def test_has_sequence_cache(self, connection, metadata):
+ insp = inspect(connection)
+ eq_(insp.has_sequence("user_id_seq"), True)
+ ss = Sequence("new_seq", metadata=metadata)
+ eq_(insp.has_sequence("new_seq"), False)
+ ss.create(connection)
+ try:
+ eq_(insp.has_sequence("new_seq"), False)
+ insp.clear_cache()
+ eq_(insp.has_sequence("new_seq"), True)
+ finally:
+ ss.drop(connection)
def test_has_sequence_other_object(self, connection):
- eq_(
- inspect(connection).has_sequence("user_id_table"),
- False,
- )
+ eq_(inspect(connection).has_sequence("user_id_table"), False)
@testing.requires.schemas
def test_has_sequence_schema(self, connection):
@@ -215,10 +222,7 @@ class HasSequenceTest(fixtures.TablesTest):
)
def test_has_sequence_neg(self, connection):
- eq_(
- inspect(connection).has_sequence("some_sequence"),
- False,
- )
+ eq_(inspect(connection).has_sequence("some_sequence"), False)
@testing.requires.schemas
def test_has_sequence_schemas_neg(self, connection):
@@ -240,10 +244,7 @@ class HasSequenceTest(fixtures.TablesTest):
@testing.requires.schemas
def test_has_sequence_remote_not_in_default(self, connection):
- eq_(
- inspect(connection).has_sequence("schema_seq"),
- False,
- )
+ eq_(inspect(connection).has_sequence("schema_seq"), False)
def test_get_sequence_names(self, connection):
exp = {"other_seq", "user_id_seq"}
diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py
index 0070b4d67..6fd42af70 100644
--- a/lib/sqlalchemy/testing/util.py
+++ b/lib/sqlalchemy/testing/util.py
@@ -393,36 +393,55 @@ def drop_all_tables_from_metadata(metadata, engine_or_connection):
go(engine_or_connection)
-def drop_all_tables(engine, inspector, schema=None, include_names=None):
+def drop_all_tables(
+ engine,
+ inspector,
+ schema=None,
+ consider_schemas=(None,),
+ include_names=None,
+):
if include_names is not None:
include_names = set(include_names)
+ if schema is not None:
+ assert consider_schemas == (
+ None,
+ ), "consider_schemas and schema are mutually exclusive"
+ consider_schemas = (schema,)
+
with engine.begin() as conn:
- for tname, fkcs in reversed(
- inspector.get_sorted_table_and_fkc_names(schema=schema)
+ for table_key, fkcs in reversed(
+ inspector.sort_tables_on_foreign_key_dependency(
+ consider_schemas=consider_schemas
+ )
):
- if tname:
- if include_names is not None and tname not in include_names:
+ if table_key:
+ if (
+ include_names is not None
+ and table_key[1] not in include_names
+ ):
continue
conn.execute(
- DropTable(Table(tname, MetaData(), schema=schema))
+ DropTable(
+ Table(table_key[1], MetaData(), schema=table_key[0])
+ )
)
elif fkcs:
if not engine.dialect.supports_alter:
continue
- for tname, fkc in fkcs:
+ for t_key, fkc in fkcs:
if (
include_names is not None
- and tname not in include_names
+ and t_key[1] not in include_names
):
continue
tb = Table(
- tname,
+ t_key[1],
MetaData(),
Column("x", Integer),
Column("y", Integer),
- schema=schema,
+ schema=t_key[0],
)
conn.execute(
DropConstraint(
diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py
index 24e478b57..620e3bbb7 100644
--- a/lib/sqlalchemy/util/topological.py
+++ b/lib/sqlalchemy/util/topological.py
@@ -10,6 +10,7 @@
from __future__ import annotations
from typing import Any
+from typing import Collection
from typing import DefaultDict
from typing import Iterable
from typing import Iterator
@@ -27,7 +28,7 @@ __all__ = ["sort", "sort_as_subsets", "find_cycles"]
def sort_as_subsets(
- tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
+ tuples: Collection[Tuple[_T, _T]], allitems: Collection[_T]
) -> Iterator[Sequence[_T]]:
edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set)
@@ -56,8 +57,8 @@ def sort_as_subsets(
def sort(
- tuples: Iterable[Tuple[_T, _T]],
- allitems: Iterable[_T],
+ tuples: Collection[Tuple[_T, _T]],
+ allitems: Collection[_T],
deterministic_order: bool = True,
) -> Iterator[_T]:
"""sort the given list of items by dependency.
@@ -76,8 +77,7 @@ def sort(
def find_cycles(
- tuples: Iterable[Tuple[_T, _T]],
- allitems: Iterable[_T],
+ tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T]
) -> Set[_T]:
# adapted from:
# https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html
diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py
index eb625e06e..4e76554c7 100644
--- a/lib/sqlalchemy/util/typing.py
+++ b/lib/sqlalchemy/util/typing.py
@@ -78,11 +78,13 @@ if typing.TYPE_CHECKING or compat.py38:
from typing import Protocol as Protocol
from typing import TypedDict as TypedDict
from typing import Final as Final
+ from typing import final as final
else:
from typing_extensions import Literal as Literal # noqa: F401
from typing_extensions import Protocol as Protocol # noqa: F401
from typing_extensions import TypedDict as TypedDict # noqa: F401
from typing_extensions import Final as Final # noqa: F401
+ from typing_extensions import final as final # noqa: F401
typing_get_args = get_args
typing_get_origin = get_origin