48 files changed, 9522 insertions, 3223 deletions
diff --git a/README.unittests.rst b/README.unittests.rst
index 0e59d8e92..23f7292cc 100644
--- a/README.unittests.rst
+++ b/README.unittests.rst
@@ -310,8 +310,13 @@ be used with pytest by using ``--db docker_mssql``.
     >> sqlplus system/tiger@//localhost/XEPDB1 <<EOF
     CREATE USER test_schema IDENTIFIED BY tiger;
     GRANT DBA TO SCOTT;
+    GRANT CREATE TABLE TO scott;
+    GRANT CREATE TABLE TO test_schema;
     GRANT UNLIMITED TABLESPACE TO scott;
     GRANT UNLIMITED TABLESPACE TO test_schema;
+    GRANT CREATE SESSION TO test_schema;
+    CREATE PUBLIC DATABASE LINK test_link CONNECT TO scott IDENTIFIED BY tiger USING 'XEPDB1';
+    CREATE PUBLIC DATABASE LINK test_link2 CONNECT TO test_schema IDENTIFIED BY tiger USING 'XEPDB1';
     EOF

     # To stop the container. It will also remove it.
diff --git a/doc/build/changelog/unreleased_20/4379.rst b/doc/build/changelog/unreleased_20/4379.rst
new file mode 100644
index 000000000..e761988b7
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/4379.rst
@@ -0,0 +1,69 @@
+.. change::
+    :tags: performance, schema
+    :tickets: 4379
+
+    Rearchitected the schema reflection API to allow participating dialects
+    to make use of high-performing batch queries that reflect the schemas of
+    many tables at once, using far fewer queries. The new performance
+    features are targeted first at the PostgreSQL and Oracle backends, and
+    may be applied to any dialect that makes use of SELECT queries against
+    system catalog tables to reflect tables (currently this omits the MySQL
+    and SQLite dialects, which instead parse the "CREATE TABLE" statement
+    and do not have a pre-existing performance issue with reflection; MS SQL
+    Server is still a TODO).
+
+    The new API is backwards compatible with the previous system and should
+    require no changes to third party dialects to retain compatibility;
+    third party dialects can also opt into the new system by implementing
+    batched queries for schema reflection.
+
+    Along with this change comes an updated reflection API that is fully
+    :pep:`484` typed and features many new methods as well as some
+    behavioral changes.
+
+.. change::
+    :tags: bug, schema
+
+    For the SQLAlchemy-included dialects for SQLite, PostgreSQL,
+    MySQL/MariaDB, Oracle, and SQL Server, the :meth:`.Inspector.has_table`,
+    :meth:`.Inspector.has_sequence`, :meth:`.Inspector.has_index`,
+    :meth:`.Inspector.get_table_names` and
+    :meth:`.Inspector.get_sequence_names` methods now all behave
+    consistently in terms of caching: they all fully cache their result
+    after being called the first time for a particular :class:`.Inspector`
+    object. Programs that create or drop tables/sequences while calling upon
+    the same :class:`.Inspector` object will not receive updated status
+    after the state of the database has changed. A call to
+    :meth:`.Inspector.clear_cache` or a new :class:`.Inspector` should be
+    used when DDL changes are to be executed. Previously, the
+    :meth:`.Inspector.has_table` and :meth:`.Inspector.has_sequence` methods
+    did not implement caching, while the :meth:`.Inspector.get_table_names`
+    and :meth:`.Inspector.get_sequence_names` methods did, leading to
+    inconsistent results between the two kinds of method.
+
+    Behavior for third party dialects depends on whether or not they
+    implement the "reflection cache" decorator for the dialect-level
+    implementations of these methods.
+
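As a quick illustration of the caching behavior described in the entry above
(a minimal sketch, not part of the commit itself, assuming a SQLAlchemy 2.0
:class:`.Inspector`, a hypothetical connection URL, and an existing
"user_account" table):

    from sqlalchemy import create_engine, inspect

    # hypothetical connection URL; any SQLAlchemy-supported backend works
    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

    insp = inspect(engine)

    # the first call runs catalog queries; the result is cached on this
    # Inspector, as are has_table() / has_sequence() after this change
    insp.get_table_names()
    insp.has_table("user_account")

    # DDL executed elsewhere is not visible through the cached results;
    # clear the cache (or build a new Inspector) to re-query the database
    insp.clear_cache()
    insp.get_table_names()
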
+.. change::
+    :tags: change, schema
+    :tickets: 4379
+
+    Improvements to the :class:`.Inspector` object:
+
+    * added a method :meth:`.Inspector.has_schema` that returns whether a
+      schema is present in the target database
+    * added a method :meth:`.Inspector.has_index` that returns whether a
+      table has a particular index
+    * inspection methods such as :meth:`.Inspector.get_columns` that work on
+      a single table at a time should now all consistently raise
+      :class:`_exc.NoSuchTableError` if a table or view is not found; this
+      change is specific to the individual dialects, so may not yet apply to
+      existing third-party dialects
+    * separated the handling of "views" and "materialized views", since in
+      real-world use these two constructs require different DDL for CREATE
+      and DROP; this includes separate :meth:`.Inspector.get_view_names` and
+      :meth:`.Inspector.get_materialized_view_names` methods
+
diff --git a/doc/build/changelog/unreleased_20/oracle_views.rst b/doc/build/changelog/unreleased_20/oracle_views.rst
new file mode 100644
index 000000000..61a509010
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/oracle_views.rst
@@ -0,0 +1,11 @@
+.. change::
+    :tags: change, oracle
+    :tickets: 4379
+
+    Materialized views on Oracle are now reflected as views.
+    In previous versions of SQLAlchemy, materialized views were returned
+    among the table names, not among the view names. As a side effect of
+    this change they are no longer reflected by default by
+    :meth:`_sql.MetaData.reflect`, unless ``views=True`` is set.
+    To get a list of materialized views, use the new
+    inspection method :meth:`.Inspector.get_materialized_view_names`.
diff --git a/doc/build/changelog/unreleased_20/postgresql_version.rst b/doc/build/changelog/unreleased_20/postgresql_version.rst
new file mode 100644
index 000000000..112da08f7
--- /dev/null
+++ b/doc/build/changelog/unreleased_20/postgresql_version.rst
@@ -0,0 +1,5 @@
+.. change::
+    :tags: change, postgresql
+
+    SQLAlchemy now requires PostgreSQL version 9 or greater.
+    Older versions may still work in some limited use cases.
diff --git a/doc/build/tutorial/metadata.rst b/doc/build/tutorial/metadata.rst
index 6444ed692..e3b257be6 100644
--- a/doc/build/tutorial/metadata.rst
+++ b/doc/build/tutorial/metadata.rst
@@ -515,7 +515,7 @@ using the :paramref:`_schema.Table.autoload_with` parameter:
     {opensql}BEGIN (implicit)
     PRAGMA main.table_...info("some_table")
     [raw sql] ()
-    SELECT sql FROM  (SELECT * FROM sqlite_master UNION ALL   SELECT * FROM sqlite_temp_master) WHERE name = ? AND type = 'table'
+    SELECT sql FROM  (SELECT * FROM sqlite_master UNION ALL   SELECT * FROM sqlite_temp_master) WHERE name = ? AND type in ('table', 'view')
     [raw sql] ('some_table',)
     PRAGMA main.foreign_key_list("some_table")
     ...
diff --git a/lib/sqlalchemy/dialects/mssql/base.py b/lib/sqlalchemy/dialects/mssql/base.py
index 12f495d6e..2a4362ccb 100644
--- a/lib/sqlalchemy/dialects/mssql/base.py
+++ b/lib/sqlalchemy/dialects/mssql/base.py
@@ -831,6 +831,7 @@ from ...
import util from ...engine import cursor as _cursor from ...engine import default from ...engine import reflection +from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -3010,55 +3011,16 @@ class MSDialect(default.DefaultDialect): return self.schema_name @_db_plus_owner - def has_table(self, connection, tablename, dbname, owner, schema): + def has_table(self, connection, tablename, dbname, owner, schema, **kw): self._ensure_has_table_connection(connection) - if tablename.startswith("#"): # temporary table - # mssql does not support temporary views - # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed - tables = ischema.mssql_temp_table_columns - s = sql.select(tables.c.table_name).where( - tables.c.table_name.like( - self._temp_table_name_like_pattern(tablename) - ) - ) - - # #7168: fetch all (not just first match) in case some other #temp - # table with the same name happens to appear first - table_names = connection.execute(s).scalars().fetchall() - # #6910: verify it's not a temp table from another session - for table_name in table_names: - if bool( - connection.scalar( - text("SELECT object_id(:table_name)"), - {"table_name": "tempdb.dbo.[{}]".format(table_name)}, - ) - ): - return True - else: - return False - else: - tables = ischema.tables - - s = sql.select(tables.c.table_name).where( - sql.and_( - sql.or_( - tables.c.table_type == "BASE TABLE", - tables.c.table_type == "VIEW", - ), - tables.c.table_name == tablename, - ) - ) - - if owner: - s = s.where(tables.c.table_schema == owner) - - c = connection.execute(s) - - return c.first() is not None + return self._internal_has_table(connection, tablename, owner, **kw) + @reflection.cache @_db_plus_owner - def has_sequence(self, connection, sequencename, dbname, owner, schema): + def has_sequence( + self, connection, sequencename, dbname, owner, schema, **kw + ): sequences = ischema.sequences s = sql.select(sequences.c.sequence_name).where( @@ -3128,6 +3090,60 @@ class MSDialect(default.DefaultDialect): return view_names @reflection.cache + def _internal_has_table(self, connection, tablename, owner, **kw): + if tablename.startswith("#"): # temporary table + # mssql does not support temporary views + # SQL Error [4103] [S0001]: "#v": Temporary views are not allowed + tables = ischema.mssql_temp_table_columns + + s = sql.select(tables.c.table_name).where( + tables.c.table_name.like( + self._temp_table_name_like_pattern(tablename) + ) + ) + + # #7168: fetch all (not just first match) in case some other #temp + # table with the same name happens to appear first + table_names = connection.scalars(s).all() + # #6910: verify it's not a temp table from another session + for table_name in table_names: + if bool( + connection.scalar( + text("SELECT object_id(:table_name)"), + {"table_name": "tempdb.dbo.[{}]".format(table_name)}, + ) + ): + return True + else: + return False + else: + tables = ischema.tables + + s = sql.select(tables.c.table_name).where( + sql.and_( + sql.or_( + tables.c.table_type == "BASE TABLE", + tables.c.table_type == "VIEW", + ), + tables.c.table_name == tablename, + ) + ) + + if owner: + s = s.where(tables.c.table_schema == owner) + + c = connection.execute(s) + + return c.first() is not None + + def _default_or_error(self, connection, tablename, owner, method, **kw): + # TODO: try to avoid having to run a separate query here + if self._internal_has_table(connection, tablename, owner, **kw): + return method() + else: + 
raise exc.NoSuchTableError(f"{owner}.{tablename}") + + @reflection.cache @_db_plus_owner def get_indexes(self, connection, tablename, dbname, owner, schema, **kw): filter_definition = ( @@ -3138,14 +3154,14 @@ class MSDialect(default.DefaultDialect): rp = connection.execution_options(future_result=True).execute( sql.text( "select ind.index_id, ind.is_unique, ind.name, " - "%s " + f"{filter_definition} " "from sys.indexes as ind join sys.tables as tab on " "ind.object_id=tab.object_id " "join sys.schemas as sch on sch.schema_id=tab.schema_id " "where tab.name = :tabname " "and sch.name=:schname " - "and ind.is_primary_key=0 and ind.type != 0" - % filter_definition + "and ind.is_primary_key=0 and ind.type != 0 " + "order by ind.name " ) .bindparams( sql.bindparam("tabname", tablename, ischema.CoerceUnicode()), @@ -3203,31 +3219,34 @@ class MSDialect(default.DefaultDialect): "mssql_include" ] = index_info["include_columns"] - return list(indexes.values()) + if indexes: + return list(indexes.values()) + else: + return self._default_or_error( + connection, tablename, owner, ReflectionDefaults.indexes, **kw + ) @reflection.cache @_db_plus_owner def get_view_definition( self, connection, viewname, dbname, owner, schema, **kw ): - rp = connection.execute( + view_def = connection.execute( sql.text( - "select definition from sys.sql_modules as mod, " - "sys.views as views, " - "sys.schemas as sch" - " where " - "mod.object_id=views.object_id and " - "views.schema_id=sch.schema_id and " - "views.name=:viewname and sch.name=:schname" + "select mod.definition " + "from sys.sql_modules as mod " + "join sys.views as views on mod.object_id = views.object_id " + "join sys.schemas as sch on views.schema_id = sch.schema_id " + "where views.name=:viewname and sch.name=:schname" ).bindparams( sql.bindparam("viewname", viewname, ischema.CoerceUnicode()), sql.bindparam("schname", owner, ischema.CoerceUnicode()), ) - ) - - if rp: - view_def = rp.scalar() + ).scalar() + if view_def: return view_def + else: + raise exc.NoSuchTableError(f"{owner}.{viewname}") def _temp_table_name_like_pattern(self, tablename): # LIKE uses '%' to match zero or more characters and '_' to match any @@ -3417,7 +3436,12 @@ class MSDialect(default.DefaultDialect): cols.append(cdict) - return cols + if cols: + return cols + else: + return self._default_or_error( + connection, tablename, owner, ReflectionDefaults.columns, **kw + ) @reflection.cache @_db_plus_owner @@ -3450,7 +3474,16 @@ class MSDialect(default.DefaultDialect): pkeys.append(row["COLUMN_NAME"]) if constraint_name is None: constraint_name = row[C.c.constraint_name.name] - return {"constrained_columns": pkeys, "name": constraint_name} + if pkeys: + return {"constrained_columns": pkeys, "name": constraint_name} + else: + return self._default_or_error( + connection, + tablename, + owner, + ReflectionDefaults.pk_constraint, + **kw, + ) @reflection.cache @_db_plus_owner @@ -3591,7 +3624,7 @@ index_info AS ( fkeys = util.defaultdict(fkey_rec) - for r in connection.execute(s).fetchall(): + for r in connection.execute(s).all(): ( _, # constraint schema rfknm, @@ -3632,4 +3665,13 @@ index_info AS ( local_cols.append(scol) remote_cols.append(rcol) - return list(fkeys.values()) + if fkeys: + return list(fkeys.values()) + else: + return self._default_or_error( + connection, + tablename, + owner, + ReflectionDefaults.foreign_keys, + **kw, + ) diff --git a/lib/sqlalchemy/dialects/mysql/base.py b/lib/sqlalchemy/dialects/mysql/base.py index 7c9a68236..502371be9 100644 --- 
a/lib/sqlalchemy/dialects/mysql/base.py +++ b/lib/sqlalchemy/dialects/mysql/base.py @@ -1056,6 +1056,7 @@ from ... import sql from ... import util from ...engine import default from ...engine import reflection +from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -2648,7 +2649,8 @@ class MySQLDialect(default.DefaultDialect): def _get_default_schema_name(self, connection): return connection.exec_driver_sql("SELECT DATABASE()").scalar() - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): self._ensure_has_table_connection(connection) if schema is None: @@ -2670,7 +2672,8 @@ class MySQLDialect(default.DefaultDialect): ) return bool(rs.scalar()) - def has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): if not self.supports_sequences: self._sequences_not_supported() if not schema: @@ -2847,14 +2850,20 @@ class MySQLDialect(default.DefaultDialect): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - return parsed_state.table_options + if parsed_state.table_options: + return parsed_state.table_options + else: + return ReflectionDefaults.table_options() @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - return parsed_state.columns + if parsed_state.columns: + return parsed_state.columns + else: + return ReflectionDefaults.columns() @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): @@ -2866,7 +2875,7 @@ class MySQLDialect(default.DefaultDialect): # There can be only one. 
cols = [s[0] for s in key["columns"]] return {"constrained_columns": cols, "name": None} - return {"constrained_columns": [], "name": None} + return ReflectionDefaults.pk_constraint() @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -2909,7 +2918,7 @@ class MySQLDialect(default.DefaultDialect): if self._needs_correct_for_88718_96365: self._correct_for_mysql_bugs_88718_96365(fkeys, connection) - return fkeys + return fkeys if fkeys else ReflectionDefaults.foreign_keys() def _correct_for_mysql_bugs_88718_96365(self, fkeys, connection): # Foreign key is always in lower case (MySQL 8.0) @@ -3000,21 +3009,22 @@ class MySQLDialect(default.DefaultDialect): connection, table_name, schema, **kw ) - return [ + cks = [ {"name": spec["name"], "sqltext": spec["sqltext"]} for spec in parsed_state.ck_constraints ] + return cks if cks else ReflectionDefaults.check_constraints() @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): parsed_state = self._parsed_state_or_create( connection, table_name, schema, **kw ) - return { - "text": parsed_state.table_options.get( - "%s_comment" % self.name, None - ) - } + comment = parsed_state.table_options.get(f"{self.name}_comment", None) + if comment is not None: + return {"text": comment} + else: + return ReflectionDefaults.table_comment() @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): @@ -3058,7 +3068,8 @@ class MySQLDialect(default.DefaultDialect): if flavor: index_d["type"] = flavor indexes.append(index_d) - return indexes + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + return indexes if indexes else ReflectionDefaults.indexes() @reflection.cache def get_unique_constraints( @@ -3068,7 +3079,7 @@ class MySQLDialect(default.DefaultDialect): connection, table_name, schema, **kw ) - return [ + ucs = [ { "name": key["name"], "column_names": [col[0] for col in key["columns"]], @@ -3077,6 +3088,11 @@ class MySQLDialect(default.DefaultDialect): for key in parsed_state.keys if key["type"] == "UNIQUE" ] + ucs.sort(key=lambda d: d["name"] or "~") # sort None as last + if ucs: + return ucs + else: + return ReflectionDefaults.unique_constraints() @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): @@ -3088,6 +3104,9 @@ class MySQLDialect(default.DefaultDialect): sql = self._show_create_table( connection, None, charset, full_name=full_name ) + if sql.upper().startswith("CREATE TABLE"): + # it's a table, not a view + raise exc.NoSuchTableError(full_name) return sql def _parsed_state_or_create( diff --git a/lib/sqlalchemy/dialects/oracle/base.py b/lib/sqlalchemy/dialects/oracle/base.py index faac0deb7..fee098889 100644 --- a/lib/sqlalchemy/dialects/oracle/base.py +++ b/lib/sqlalchemy/dialects/oracle/base.py @@ -518,21 +518,52 @@ columns for non-unique indexes, all but the last column for unique indexes). """ # noqa -from itertools import groupby +from __future__ import annotations + +from collections import defaultdict +from functools import lru_cache +from functools import wraps import re +from . 
import dictionary +from .types import _OracleBoolean +from .types import _OracleDate +from .types import BFILE +from .types import BINARY_DOUBLE +from .types import BINARY_FLOAT +from .types import DATE +from .types import FLOAT +from .types import INTERVAL +from .types import LONG +from .types import NCLOB +from .types import NUMBER +from .types import NVARCHAR2 # noqa +from .types import OracleRaw # noqa +from .types import RAW +from .types import ROWID # noqa +from .types import VARCHAR2 # noqa from ... import Computed from ... import exc from ... import schema as sa_schema from ... import sql from ... import util from ...engine import default +from ...engine import ObjectKind +from ...engine import ObjectScope from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import and_ +from ...sql import bindparam from ...sql import compiler from ...sql import expression +from ...sql import func +from ...sql import null +from ...sql import or_ +from ...sql import select from ...sql import sqltypes from ...sql import util as sql_util from ...sql import visitors +from ...sql.visitors import InternalTraversal from ...types import BLOB from ...types import CHAR from ...types import CLOB @@ -561,229 +592,6 @@ NO_ARG_FNS = set( ) -class RAW(sqltypes._Binary): - __visit_name__ = "RAW" - - -OracleRaw = RAW - - -class NCLOB(sqltypes.Text): - __visit_name__ = "NCLOB" - - -class VARCHAR2(VARCHAR): - __visit_name__ = "VARCHAR2" - - -NVARCHAR2 = NVARCHAR - - -class NUMBER(sqltypes.Numeric, sqltypes.Integer): - __visit_name__ = "NUMBER" - - def __init__(self, precision=None, scale=None, asdecimal=None): - if asdecimal is None: - asdecimal = bool(scale and scale > 0) - - super(NUMBER, self).__init__( - precision=precision, scale=scale, asdecimal=asdecimal - ) - - def adapt(self, impltype): - ret = super(NUMBER, self).adapt(impltype) - # leave a hint for the DBAPI handler - ret._is_oracle_number = True - return ret - - @property - def _type_affinity(self): - if bool(self.scale and self.scale > 0): - return sqltypes.Numeric - else: - return sqltypes.Integer - - -class FLOAT(sqltypes.FLOAT): - """Oracle FLOAT. - - This is the same as :class:`_sqltypes.FLOAT` except that - an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` - parameter is accepted, and - the :paramref:`_sqltypes.Float.precision` parameter is not accepted. - - Oracle FLOAT types indicate precision in terms of "binary precision", which - defaults to 126. For a REAL type, the value is 63. This parameter does not - cleanly map to a specific number of decimal places but is roughly - equivalent to the desired number of decimal places divided by 0.3103. - - .. versionadded:: 2.0 - - """ - - __visit_name__ = "FLOAT" - - def __init__( - self, - binary_precision=None, - asdecimal=False, - decimal_return_scale=None, - ): - r""" - Construct a FLOAT - - :param binary_precision: Oracle binary precision value to be rendered - in DDL. This may be approximated to the number of decimal characters - using the formula "decimal precision = 0.30103 * binary precision". - The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. 
- - :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` - - :param decimal_return_scale: See - :paramref:`_sqltypes.Float.decimal_return_scale` - - """ - super().__init__( - asdecimal=asdecimal, decimal_return_scale=decimal_return_scale - ) - self.binary_precision = binary_precision - - -class BINARY_DOUBLE(sqltypes.Float): - __visit_name__ = "BINARY_DOUBLE" - - -class BINARY_FLOAT(sqltypes.Float): - __visit_name__ = "BINARY_FLOAT" - - -class BFILE(sqltypes.LargeBinary): - __visit_name__ = "BFILE" - - -class LONG(sqltypes.Text): - __visit_name__ = "LONG" - - -class _OracleDateLiteralRender: - def _literal_processor_datetime(self, dialect): - def process(value): - if value is not None: - if getattr(value, "microsecond", None): - value = ( - f"""TO_TIMESTAMP""" - f"""('{value.isoformat().replace("T", " ")}', """ - """'YYYY-MM-DD HH24:MI:SS.FF')""" - ) - else: - value = ( - f"""TO_DATE""" - f"""('{value.isoformat().replace("T", " ")}', """ - """'YYYY-MM-DD HH24:MI:SS')""" - ) - return value - - return process - - def _literal_processor_date(self, dialect): - def process(value): - if value is not None: - if getattr(value, "microsecond", None): - value = ( - f"""TO_TIMESTAMP""" - f"""('{value.isoformat().split("T")[0]}', """ - """'YYYY-MM-DD')""" - ) - else: - value = ( - f"""TO_DATE""" - f"""('{value.isoformat().split("T")[0]}', """ - """'YYYY-MM-DD')""" - ) - return value - - return process - - -class DATE(_OracleDateLiteralRender, sqltypes.DateTime): - """Provide the oracle DATE type. - - This type has no special Python behavior, except that it subclasses - :class:`_types.DateTime`; this is to suit the fact that the Oracle - ``DATE`` type supports a time value. - - .. versionadded:: 0.9.4 - - """ - - __visit_name__ = "DATE" - - def literal_processor(self, dialect): - return self._literal_processor_datetime(dialect) - - def _compare_type_affinity(self, other): - return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) - - -class _OracleDate(_OracleDateLiteralRender, sqltypes.Date): - def literal_processor(self, dialect): - return self._literal_processor_date(dialect) - - -class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): - __visit_name__ = "INTERVAL" - - def __init__(self, day_precision=None, second_precision=None): - """Construct an INTERVAL. - - Note that only DAY TO SECOND intervals are currently supported. - This is due to a lack of support for YEAR TO MONTH intervals - within available DBAPIs. - - :param day_precision: the day precision value. this is the number of - digits to store for the day field. Defaults to "2" - :param second_precision: the second precision value. this is the - number of digits to store for the fractional seconds field. - Defaults to "6". - - """ - self.day_precision = day_precision - self.second_precision = second_precision - - @classmethod - def _adapt_from_generic_interval(cls, interval): - return INTERVAL( - day_precision=interval.day_precision, - second_precision=interval.second_precision, - ) - - @property - def _type_affinity(self): - return sqltypes.Interval - - def as_generic(self, allow_nulltype=False): - return sqltypes.Interval( - native=True, - second_precision=self.second_precision, - day_precision=self.day_precision, - ) - - -class ROWID(sqltypes.TypeEngine): - """Oracle ROWID type. - - When used in a cast() or similar, generates ROWID. 
- - """ - - __visit_name__ = "ROWID" - - -class _OracleBoolean(sqltypes.Boolean): - def get_dbapi_type(self, dbapi): - return dbapi.NUMBER - - colspecs = { sqltypes.Boolean: _OracleBoolean, sqltypes.Interval: INTERVAL, @@ -1541,6 +1349,13 @@ class OracleExecutionContext(default.DefaultExecutionContext): type_, ) + def pre_exec(self): + if self.statement and "_oracle_dblink" in self.execution_options: + self.statement = self.statement.replace( + dictionary.DB_LINK_PLACEHOLDER, + self.execution_options["_oracle_dblink"], + ) + class OracleDialect(default.DefaultDialect): name = "oracle" @@ -1675,6 +1490,10 @@ class OracleDialect(default.DefaultDialect): # it may work also on versions before the 18 return self.server_version_info and self.server_version_info >= (18,) + @property + def _supports_except_all(self): + return self.server_version_info and self.server_version_info >= (21,) + def do_release_savepoint(self, connection, name): # Oracle does not support RELEASE SAVEPOINT pass @@ -1700,45 +1519,99 @@ class OracleDialect(default.DefaultDialect): except: return "READ COMMITTED" - def has_table(self, connection, table_name, schema=None): + def _execute_reflection( + self, connection, query, dblink, returns_long, params=None + ): + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + execution_options = { + # handle db links + "_oracle_dblink": dblink or "", + # override any schema translate map + "schema_translate_map": None, + } + + if dblink and returns_long: + # Oracle seems to error with + # "ORA-00997: illegal use of LONG datatype" when returning + # LONG columns via a dblink in a query with bind params + # This type seems to be very hard to cast into something else + # so it seems easier to just use bind param in this case + def visit_bindparam(bindparam): + bindparam.literal_execute = True + + query = visitors.cloned_traverse( + query, {}, {"bindparam": visit_bindparam} + ) + return connection.execute( + query, params, execution_options=execution_options + ) + + @util.memoized_property + def _has_table_query(self): + # materialized views are returned by all_tables + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + ) + .union_all( + select( + dictionary.all_views.c.view_name.label("table_name"), + dictionary.all_views.c.owner, + ) + ) + .subquery("tables_and_views") + ) + + query = select(tables.c.table_name).where( + tables.c.table_name == bindparam("table_name"), + tables.c.owner == bindparam("owner"), + ) + return query + + @reflection.cache + def has_table( + self, connection, table_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" self._ensure_has_table_connection(connection) if not schema: schema = self.default_schema_name - cursor = connection.execute( - sql.text( - """SELECT table_name FROM all_tables - WHERE table_name = CAST(:name AS VARCHAR2(128)) - AND owner = CAST(:schema_name AS VARCHAR2(128)) - UNION ALL - SELECT view_name FROM all_views - WHERE view_name = CAST(:name AS VARCHAR2(128)) - AND owner = CAST(:schema_name AS VARCHAR2(128)) - """ - ), - dict( - name=self.denormalize_name(table_name), - schema_name=self.denormalize_name(schema), - ), + params = { + "table_name": self.denormalize_name(table_name), + "owner": self.denormalize_name(schema), + } + cursor = self._execute_reflection( + connection, + self._has_table_query, + dblink, + returns_long=False, + params=params, ) - return cursor.first() is not None + return bool(cursor.scalar()) - def 
has_sequence(self, connection, sequence_name, schema=None): + @reflection.cache + def has_sequence( + self, connection, sequence_name, schema=None, dblink=None, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" if not schema: schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT sequence_name FROM all_sequences " - "WHERE sequence_name = :name AND " - "sequence_owner = :schema_name" - ), - dict( - name=self.denormalize_name(sequence_name), - schema_name=self.denormalize_name(schema), - ), + + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_name + == self.denormalize_name(sequence_name), + dictionary.all_sequences.c.sequence_owner + == self.denormalize_name(schema), ) - return cursor.first() is not None + + cursor = self._execute_reflection( + connection, query, dblink, returns_long=False + ) + return bool(cursor.scalar()) def _get_default_schema_name(self, connection): return self.normalize_name( @@ -1747,329 +1620,633 @@ class OracleDialect(default.DefaultDialect): ).scalar() ) - def _resolve_synonym( - self, - connection, - desired_owner=None, - desired_synonym=None, - desired_table=None, + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_synonyms(self, connection, schema, filter_names, dblink, **kw): + owner = self.denormalize_name(schema or self.default_schema_name) + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = select( + dictionary.all_synonyms.c.synonym_name, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.db_link, + ).where(dictionary.all_synonyms.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_synonyms.c.synonym_name.in_( + params["filter_names"] + ) + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).mappings() + return result.all() + + @lru_cache() + def _all_objects_query( + self, owner, scope, kind, has_filter_names, has_mat_views ): - """search for a local synonym matching the given desired owner/name. - - if desired_owner is None, attempts to locate a distinct owner. - - returns the actual name, owner, dblink name, and synonym name if - found. - """ - - q = ( - "SELECT owner, table_owner, table_name, db_link, " - "synonym_name FROM all_synonyms WHERE " + query = ( + select(dictionary.all_objects.c.object_name) + .select_from(dictionary.all_objects) + .where(dictionary.all_objects.c.owner == owner) ) - clauses = [] - params = {} - if desired_synonym: - clauses.append( - "synonym_name = CAST(:synonym_name AS VARCHAR2(128))" + + # NOTE: materialized views are listed in all_objects twice; + # once as MATERIALIZE VIEW and once as TABLE + if kind is ObjectKind.ANY: + # materilaized view are listed also as tables so there is no + # need to add them to the in_. 
+ query = query.where( + dictionary.all_objects.c.object_type.in_(("TABLE", "VIEW")) ) - params["synonym_name"] = desired_synonym - if desired_owner: - clauses.append("owner = CAST(:desired_owner AS VARCHAR2(128))") - params["desired_owner"] = desired_owner - if desired_table: - clauses.append("table_name = CAST(:tname AS VARCHAR2(128))") - params["tname"] = desired_table - - q += " AND ".join(clauses) - - result = connection.execution_options(future_result=True).execute( - sql.text(q), params - ) - if desired_owner: - row = result.mappings().first() - if row: - return ( - row["table_name"], - row["table_owner"], - row["db_link"], - row["synonym_name"], - ) - else: - return None, None, None, None else: - rows = result.mappings().all() - if len(rows) > 1: - raise AssertionError( - "There are multiple tables visible to the schema, you " - "must specify owner" - ) - elif len(rows) == 1: - row = rows[0] - return ( - row["table_name"], - row["table_owner"], - row["db_link"], - row["synonym_name"], - ) - else: - return None, None, None, None + object_type = [] + if ObjectKind.VIEW in kind: + object_type.append("VIEW") + if ( + ObjectKind.MATERIALIZED_VIEW in kind + and ObjectKind.TABLE not in kind + ): + # materilaized view are listed also as tables so there is no + # need to add them to the in_ if also selecting tables. + object_type.append("MATERIALIZED VIEW") + if ObjectKind.TABLE in kind: + object_type.append("TABLE") + if has_mat_views and ObjectKind.MATERIALIZED_VIEW not in kind: + # materialized view are listed also as tables, + # so they need to be filtered out + # EXCEPT ALL / MINUS profiles as faster than using + # NOT EXISTS or NOT IN with a subquery, but it's in + # general faster to get the mat view names and exclude + # them only when needed + query = query.where( + dictionary.all_objects.c.object_name.not_in( + bindparam("mat_views") + ) + ) + query = query.where( + dictionary.all_objects.c.object_type.in_(object_type) + ) - @reflection.cache - def _prepare_reflection_args( - self, - connection, - table_name, - schema=None, - resolve_synonyms=False, - dblink="", - **kw, - ): + # handles scope + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_objects.c.temporary == "N") + elif scope is ObjectScope.TEMPORARY: + query = query.where(dictionary.all_objects.c.temporary == "Y") - if resolve_synonyms: - actual_name, owner, dblink, synonym = self._resolve_synonym( - connection, - desired_owner=self.denormalize_name(schema), - desired_synonym=self.denormalize_name(table_name), + if has_filter_names: + query = query.where( + dictionary.all_objects.c.object_name.in_( + bindparam("filter_names") + ) ) - else: - actual_name, owner, dblink, synonym = None, None, None, None - if not actual_name: - actual_name = self.denormalize_name(table_name) - - if dblink: - # using user_db_links here since all_db_links appears - # to have more restricted permissions. - # https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm - # will need to hear from more users if we are doing - # the right thing here. 
See [ticket:2619] - owner = connection.scalar( - sql.text( - "SELECT username FROM user_db_links " "WHERE db_link=:link" - ), - dict(link=dblink), + return query + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("scope", InternalTraversal.dp_plain_obj), + ("kind", InternalTraversal.dp_plain_obj), + ("filter_names", InternalTraversal.dp_string_list), + ("dblink", InternalTraversal.dp_string), + ) + def _get_all_objects( + self, connection, schema, scope, kind, filter_names, dblink, **kw + ): + owner = self.denormalize_name(schema or self.default_schema_name) + + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _all_objects_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw ) - dblink = "@" + dblink - elif not owner: - owner = self.denormalize_name(schema or self.default_schema_name) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + + query = self._all_objects_query( + owner, scope, kind, has_filter_names, has_mat_views + ) - return (actual_name, owner, dblink or "", synonym) + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ).scalars() - @reflection.cache - def get_schema_names(self, connection, **kw): - s = "SELECT username FROM all_users ORDER BY username" - cursor = connection.exec_driver_sql(s) - return [self.normalize_name(row[0]) for row in cursor] + return result.all() + + def _handle_synonyms_decorator(fn): + @wraps(fn) + def wrapper(self, *args, **kwargs): + return self._handle_synonyms(fn, *args, **kwargs) + + return wrapper + + def _handle_synonyms(self, fn, connection, *args, **kwargs): + if not kwargs.get("oracle_resolve_synonyms", False): + return fn(self, connection, *args, **kwargs) + + original_kw = kwargs.copy() + schema = kwargs.pop("schema", None) + result = self._get_synonyms( + connection, + schema=schema, + filter_names=kwargs.pop("filter_names", None), + dblink=kwargs.pop("dblink", None), + info_cache=kwargs.get("info_cache", None), + ) + + dblinks_owners = defaultdict(dict) + for row in result: + key = row["db_link"], row["table_owner"] + tn = self.normalize_name(row["table_name"]) + dblinks_owners[key][tn] = row["synonym_name"] + + if not dblinks_owners: + # No synonym, do the plain thing + return fn(self, connection, *args, **original_kw) + + data = {} + for (dblink, table_owner), mapping in dblinks_owners.items(): + call_kw = { + **original_kw, + "schema": table_owner, + "dblink": self.normalize_name(dblink), + "filter_names": mapping.keys(), + } + call_result = fn(self, connection, *args, **call_kw) + for (_, tn), value in call_result: + synonym_name = self.normalize_name(mapping[tn]) + data[(schema, synonym_name)] = value + return data.items() @reflection.cache - def get_table_names(self, connection, schema=None, **kw): - schema = self.denormalize_name(schema or self.default_schema_name) + def get_schema_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + query = select(dictionary.all_users.c.username).order_by( + dictionary.all_users.c.username + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] + @reflection.cache + def get_table_names(self, connection, schema=None, dblink=None, **kw): + 
"""Supported kw arguments are: ``dblink`` to reflect via a db link.""" # note that table_names() isn't loading DBLINKed or synonym'ed tables if schema is None: schema = self.default_schema_name - sql_str = "SELECT table_name FROM all_tables WHERE " + den_schema = self.denormalize_name(schema) + if kw.get("oracle_resolve_synonyms", False): + tables = ( + select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .union_all( + select( + dictionary.all_synonyms.c.synonym_name.label( + "table_name" + ), + dictionary.all_synonyms.c.owner, + dictionary.all_tables.c.iot_name, + dictionary.all_tables.c.duration, + dictionary.all_tables.c.tablespace_name, + ) + .select_from(dictionary.all_tables) + .join( + dictionary.all_synonyms, + and_( + dictionary.all_tables.c.table_name + == dictionary.all_synonyms.c.table_name, + dictionary.all_tables.c.owner + == func.coalesce( + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.owner, + ), + ), + ) + ) + .subquery("available_tables") + ) + else: + tables = dictionary.all_tables + + query = select(tables.c.table_name) if self.exclude_tablespaces: - sql_str += ( - "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " - % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) + query = query.where( + func.coalesce( + tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) ) - sql_str += ( - "OWNER = :owner " "AND IOT_NAME IS NULL " "AND DURATION IS NULL" + query = query.where( + tables.c.owner == den_schema, + tables.c.iot_name.is_(null()), + tables.c.duration.is_(null()), ) - cursor = connection.execute(sql.text(sql_str), dict(owner=schema)) - return [self.normalize_name(row[0]) for row in cursor] + # remove materialized views + mat_query = select( + dictionary.all_mviews.c.mview_name.label("table_name") + ).where(dictionary.all_mviews.c.owner == den_schema) + + query = ( + query.except_all(mat_query) + if self._supports_except_all + else query.except_(mat_query) + ) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] @reflection.cache - def get_temp_table_names(self, connection, **kw): + def get_temp_table_names(self, connection, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" schema = self.denormalize_name(self.default_schema_name) - sql_str = "SELECT table_name FROM all_tables WHERE " + query = select(dictionary.all_tables.c.table_name) if self.exclude_tablespaces: - sql_str += ( - "nvl(tablespace_name, 'no tablespace') " - "NOT IN (%s) AND " - % (", ".join(["'%s'" % ts for ts in self.exclude_tablespaces])) + query = query.where( + func.coalesce( + dictionary.all_tables.c.tablespace_name, "no tablespace" + ).not_in(self.exclude_tablespaces) ) - sql_str += ( - "OWNER = :owner " - "AND IOT_NAME IS NULL " - "AND DURATION IS NOT NULL" + query = query.where( + dictionary.all_tables.c.owner == schema, + dictionary.all_tables.c.iot_name.is_(null()), + dictionary.all_tables.c.duration.is_not(null()), ) - cursor = connection.execute(sql.text(sql_str), dict(owner=schema)) - return [self.normalize_name(row[0]) for row in cursor] + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] @reflection.cache - def get_view_names(self, 
connection, schema=None, **kw): - schema = self.denormalize_name(schema or self.default_schema_name) - s = sql.text("SELECT view_name FROM all_views WHERE owner = :owner") - cursor = connection.execute( - s, dict(owner=self.denormalize_name(schema)) + def get_materialized_view_names( + self, connection, schema=None, dblink=None, _normalize=True, **kw + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + + query = select(dictionary.all_mviews.c.mview_name).where( + dictionary.all_mviews.c.owner == self.denormalize_name(schema) ) - return [self.normalize_name(row[0]) for row in cursor] + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + if _normalize: + return [self.normalize_name(row) for row in result] + else: + return result.all() @reflection.cache - def get_sequence_names(self, connection, schema=None, **kw): + def get_view_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" if not schema: schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT sequence_name FROM all_sequences " - "WHERE sequence_owner = :schema_name" - ), - dict(schema_name=self.denormalize_name(schema)), + + query = select(dictionary.all_views.c.view_name).where( + dictionary.all_views.c.owner == self.denormalize_name(schema) ) - return [self.normalize_name(row[0]) for row in cursor] + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] @reflection.cache - def get_table_options(self, connection, table_name, schema=None, **kw): - options = {} + def get_sequence_names(self, connection, schema=None, dblink=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link.""" + if not schema: + schema = self.default_schema_name + query = select(dictionary.all_sequences.c.sequence_name).where( + dictionary.all_sequences.c.sequence_owner + == self.denormalize_name(schema) + ) - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(row) for row in result] - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + def _value_or_raise(self, data, table, schema): + table = self.normalize_name(str(table)) + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + fn = [self.denormalize_name(name) for name in filter_names] + return True, {"filter_names": fn} + else: + return False, {} + + @reflection.cache + def get_table_options(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_options( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - params = {"table_name": table_name} + @lru_cache() + def _table_options_query( + self, owner, scope, kind, 
has_filter_names, has_mat_views + ): + query = select( + dictionary.all_tables.c.table_name, + dictionary.all_tables.c.compression, + dictionary.all_tables.c.compress_for, + ).where(dictionary.all_tables.c.owner == owner) + if has_filter_names: + query = query.where( + dictionary.all_tables.c.table_name.in_( + bindparam("filter_names") + ) + ) + if scope is ObjectScope.DEFAULT: + query = query.where(dictionary.all_tables.c.duration.is_(null())) + elif scope is ObjectScope.TEMPORARY: + query = query.where( + dictionary.all_tables.c.duration.is_not(null()) + ) - columns = ["table_name"] - if self._supports_table_compression: - columns.append("compression") - if self._supports_table_compress_for: - columns.append("compress_for") + if ( + has_mat_views + and ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # cant use EXCEPT ALL / MINUS here because we don't have an + # excludable row vs. the query above + # outerjoin + where null works better on oracle 21 but 11 does + # not like it at all. this is the next best thing + + query = query.where( + dictionary.all_tables.c.table_name.not_in( + bindparam("mat_views") + ) + ) + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + query = query.where( + dictionary.all_tables.c.table_name.in_(bindparam("mat_views")) + ) + return query - text = ( - "SELECT %(columns)s " - "FROM ALL_TABLES%(dblink)s " - "WHERE table_name = CAST(:table_name AS VARCHAR(128))" - ) + @_handle_synonyms_decorator + def get_multi_table_options( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + owner = self.denormalize_name(schema or self.default_schema_name) - if schema is not None: - params["owner"] = schema - text += " AND owner = CAST(:owner AS VARCHAR(128)) " - text = text % {"dblink": dblink, "columns": ", ".join(columns)} + has_filter_names, params = self._prepare_filter_names(filter_names) + has_mat_views = False - result = connection.execute(sql.text(text), params) + if ( + ObjectKind.TABLE in kind + and ObjectKind.MATERIALIZED_VIEW not in kind + ): + # see note in _table_options_query + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + if mat_views: + params["mat_views"] = mat_views + has_mat_views = True + elif ( + ObjectKind.TABLE not in kind + and ObjectKind.MATERIALIZED_VIEW in kind + ): + mat_views = self.get_materialized_view_names( + connection, schema, dblink, _normalize=False, **kw + ) + params["mat_views"] = mat_views - enabled = dict(DISABLED=False, ENABLED=True) + options = {} + default = ReflectionDefaults.table_options - row = result.first() - if row: - if "compression" in row._fields and enabled.get( - row.compression, False - ): - if "compress_for" in row._fields: - options["oracle_compress"] = row.compress_for + if ObjectKind.TABLE in kind or ObjectKind.MATERIALIZED_VIEW in kind: + query = self._table_options_query( + owner, scope, kind, has_filter_names, has_mat_views + ) + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + + for table, compression, compress_for in result: + if compression == "ENABLED": + data = {"oracle_compress": compress_for} else: - options["oracle_compress"] = True + data = default() + options[(schema, self.normalize_name(table))] = data + if ObjectKind.VIEW in kind and 
ObjectScope.DEFAULT in scope: + # add the views (no temporary views) + for view in self.get_view_names(connection, schema, dblink, **kw): + if not filter_names or view in filter_names: + options[(schema, view)] = default() - return options + return options.items() @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ - kw arguments can be: + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def _run_batches( + self, connection, query, dblink, returns_long, mappings, all_objects + ): + each_batch = 500 + batches = list(all_objects) + while batches: + batch = batches[0:each_batch] + batches[0:each_batch] = [] + + result = self._execute_reflection( + connection, + query, + dblink, + returns_long=returns_long, + params={"all_objects": batch}, + ) + if mappings: + yield from result.mappings() + else: + yield from result + + @lru_cache() + def _column_query(self, owner): + all_cols = dictionary.all_tab_cols + all_comments = dictionary.all_col_comments + all_ids = dictionary.all_tab_identity_cols - oracle_resolve_synonyms + if self.server_version_info >= (12,): + add_cols = ( + all_cols.c.default_on_null, + sql.case( + (all_ids.c.table_name.is_(None), sql.null()), + else_=all_ids.c.generation_type + + "," + + all_ids.c.identity_options, + ).label("identity_options"), + ) + join_identity_cols = True + else: + add_cols = ( + sql.null().label("default_on_null"), + sql.null().label("identity_options"), + ) + join_identity_cols = False + + # NOTE: on oracle cannot create tables/views without columns and + # a table cannot have all column hidden: + # ORA-54039: table must have at least one column that is not invisible + # all_tab_cols returns data for tables/views/mat-views. 
+ # all_tab_cols does not return recycled tables + + query = ( + select( + all_cols.c.table_name, + all_cols.c.column_name, + all_cols.c.data_type, + all_cols.c.char_length, + all_cols.c.data_precision, + all_cols.c.data_scale, + all_cols.c.nullable, + all_cols.c.data_default, + all_comments.c.comments, + all_cols.c.virtual_column, + *add_cols, + ).select_from(all_cols) + # NOTE: all_col_comments has a row for each column even if no + # comment is present, so a join could be performed, but there + # seems to be no difference compared to an outer join + .outerjoin( + all_comments, + and_( + all_cols.c.table_name == all_comments.c.table_name, + all_cols.c.column_name == all_comments.c.column_name, + all_cols.c.owner == all_comments.c.owner, + ), + ) + ) + if join_identity_cols: + query = query.outerjoin( + all_ids, + and_( + all_cols.c.table_name == all_ids.c.table_name, + all_cols.c.column_name == all_ids.c.column_name, + all_cols.c.owner == all_ids.c.owner, + ), + ) - dblink + query = query.where( + all_cols.c.table_name.in_(bindparam("all_objects")), + all_cols.c.hidden_column == "NO", + all_cols.c.owner == owner, + ).order_by(all_cols.c.table_name, all_cols.c.column_id) + return query + @_handle_synonyms_decorator + def get_multi_columns( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ + owner = self.denormalize_name(schema or self.default_schema_name) + query = self._column_query(owner) - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") + if ( + filter_names + and kind is ObjectKind.ANY + and scope is ObjectScope.ANY + ): + all_objects = [self.denormalize_name(n) for n in filter_names] + else: + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + columns = defaultdict(list) + + # all_tab_cols.data_default is LONG + result = self._run_batches( connection, - table_name, - schema, - resolve_synonyms, + query, dblink, - info_cache=info_cache, + returns_long=True, + mappings=True, + all_objects=all_objects, ) - columns = [] - if self._supports_char_length: - char_length_col = "char_length" - else: - char_length_col = "data_length" - if self.server_version_info >= (12,): - identity_cols = """\ - col.default_on_null, - ( - SELECT id.generation_type || ',' || id.IDENTITY_OPTIONS - FROM ALL_TAB_IDENTITY_COLS%(dblink)s id - WHERE col.table_name = id.table_name - AND col.column_name = id.column_name - AND col.owner = id.owner - ) AS identity_options""" % { - "dblink": dblink - } - else: - identity_cols = "NULL as default_on_null, NULL as identity_options" - - params = {"table_name": table_name} - - text = """ - SELECT - col.column_name, - col.data_type, - col.%(char_length_col)s, - col.data_precision, - col.data_scale, - col.nullable, - col.data_default, - com.comments, - col.virtual_column, - %(identity_cols)s - FROM all_tab_cols%(dblink)s col - LEFT JOIN all_col_comments%(dblink)s com - ON col.table_name = com.table_name - AND col.column_name = com.column_name - AND col.owner = com.owner - WHERE col.table_name = CAST(:table_name AS VARCHAR2(128)) - AND col.hidden_column = 'NO' - """ - if schema is not None: - params["owner"] = schema - text += " AND col.owner = :owner " - text += " ORDER BY col.column_id" - text = 
text % { - "dblink": dblink, - "char_length_col": char_length_col, - "identity_cols": identity_cols, - } - - c = connection.execute(sql.text(text), params) - - for row in c: - colname = self.normalize_name(row[0]) - orig_colname = row[0] - coltype = row[1] - length = row[2] - precision = row[3] - scale = row[4] - nullable = row[5] == "Y" - default = row[6] - comment = row[7] - generated = row[8] - default_on_nul = row[9] - identity_options = row[10] + for row_dict in result: + table_name = self.normalize_name(row_dict["table_name"]) + orig_colname = row_dict["column_name"] + colname = self.normalize_name(orig_colname) + coltype = row_dict["data_type"] + precision = row_dict["data_precision"] if coltype == "NUMBER": + scale = row_dict["data_scale"] if precision is None and scale == 0: coltype = INTEGER() else: @@ -2089,7 +2266,9 @@ class OracleDialect(default.DefaultDialect): coltype = FLOAT(binary_precision=precision) elif coltype in ("VARCHAR2", "NVARCHAR2", "CHAR", "NCHAR"): - coltype = self.ischema_names.get(coltype)(length) + coltype = self.ischema_names.get(coltype)( + row_dict["char_length"] + ) elif "WITH TIME ZONE" in coltype: coltype = TIMESTAMP(timezone=True) else: @@ -2103,15 +2282,17 @@ class OracleDialect(default.DefaultDialect): ) coltype = sqltypes.NULLTYPE - if generated == "YES": + default = row_dict["data_default"] + if row_dict["virtual_column"] == "YES": computed = dict(sqltext=default) default = None else: computed = None + identity_options = row_dict["identity_options"] if identity_options is not None: identity = self._parse_identity_options( - identity_options, default_on_nul + identity_options, row_dict["default_on_null"] ) default = None else: @@ -2120,10 +2301,9 @@ class OracleDialect(default.DefaultDialect): cdict = { "name": colname, "type": coltype, - "nullable": nullable, + "nullable": row_dict["nullable"] == "Y", "default": default, - "autoincrement": "auto", - "comment": comment, + "comment": row_dict["comments"], } if orig_colname.lower() == orig_colname: cdict["quote"] = True @@ -2132,10 +2312,17 @@ class OracleDialect(default.DefaultDialect): if identity is not None: cdict["identity"] = identity - columns.append(cdict) - return columns + columns[(schema, table_name)].append(cdict) - def _parse_identity_options(self, identity_options, default_on_nul): + # NOTE: default not needed since all tables have columns + # default = ReflectionDefaults.columns + # return ( + # (key, value if value else default()) + # for key, value in columns.items() + # ) + return columns.items() + + def _parse_identity_options(self, identity_options, default_on_null): # identity_options is a string that starts with 'ALWAYS,' or # 'BY DEFAULT,' and continues with # START WITH: 1, INCREMENT BY: 1, MAX_VALUE: 123, MIN_VALUE: 1, @@ -2144,7 +2331,7 @@ class OracleDialect(default.DefaultDialect): parts = [p.strip() for p in identity_options.split(",")] identity = { "always": parts[0] == "ALWAYS", - "on_null": default_on_nul == "YES", + "on_null": default_on_null == "YES", } for part in parts[1:]: @@ -2168,384 +2355,641 @@ class OracleDialect(default.DefaultDialect): return identity @reflection.cache - def get_table_comment( + def get_table_comment(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_table_comment( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + 
return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _comment_query(self, owner, scope, kind, has_filter_names): + # NOTE: all_tab_comments / all_mview_comments have a row for all + # object even if they don't have comments + queries = [] + if ObjectKind.TABLE in kind or ObjectKind.VIEW in kind: + # all_tab_comments returns also plain views + tbl_view = select( + dictionary.all_tab_comments.c.table_name, + dictionary.all_tab_comments.c.comments, + ).where( + dictionary.all_tab_comments.c.owner == owner, + dictionary.all_tab_comments.c.table_name.not_like("BIN$%"), + ) + if ObjectKind.VIEW not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "TABLE" + ) + elif ObjectKind.TABLE not in kind: + tbl_view = tbl_view.where( + dictionary.all_tab_comments.c.table_type == "VIEW" + ) + queries.append(tbl_view) + if ObjectKind.MATERIALIZED_VIEW in kind: + mat_view = select( + dictionary.all_mview_comments.c.mview_name.label("table_name"), + dictionary.all_mview_comments.c.comments, + ).where( + dictionary.all_mview_comments.c.owner == owner, + dictionary.all_mview_comments.c.mview_name.not_like("BIN$%"), + ) + queries.append(mat_view) + if len(queries) == 1: + query = queries[0] + else: + union = sql.union_all(*queries).subquery("tables_and_views") + query = select(union.c.table_name, union.c.comments) + + name_col = query.selected_columns.table_name + + if scope in (ObjectScope.DEFAULT, ObjectScope.TEMPORARY): + temp = "Y" if scope is ObjectScope.TEMPORARY else "N" + # need distinct since materialized view are listed also + # as tables in all_objects + query = query.distinct().join( + dictionary.all_objects, + and_( + dictionary.all_objects.c.owner == owner, + dictionary.all_objects.c.object_name == name_col, + dictionary.all_objects.c.temporary == temp, + ), + ) + if has_filter_names: + query = query.where(name_col.in_(bindparam("filter_names"))) + return query + + @_handle_synonyms_decorator + def get_multi_table_comment( self, connection, - table_name, - schema=None, - resolve_synonyms=False, - dblink="", + *, + schema, + filter_names, + scope, + kind, + dblink=None, **kw, ): - - info_cache = kw.get("info_cache") - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( - connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, - ) - - if not schema: - schema = self.default_schema_name - - COMMENT_SQL = """ - SELECT comments - FROM all_tab_comments - WHERE table_name = CAST(:table_name AS VARCHAR(128)) - AND owner = CAST(:schema_name AS VARCHAR(128)) + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ + owner = self.denormalize_name(schema or self.default_schema_name) + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(owner, scope, kind, has_filter_names) - c = connection.execute( - sql.text(COMMENT_SQL), - dict(table_name=table_name, schema_name=schema), + result = self._execute_reflection( + connection, query, dblink, returns_long=False, params=params + ) + default = ReflectionDefaults.table_comment + # materialized views by default seem to have a comment like + # "snapshot table for snapshot owner.mat_view_name" + ignore_mat_view = "snapshot table for snapshot " + return ( + ( + (schema, self.normalize_name(table)), + {"text": comment} + if comment is not None + and not comment.startswith(ignore_mat_view) + else default(), + ) + for table, comment in result ) - 
return {"text": c.scalar()} @reflection.cache - def get_indexes( - self, - connection, - table_name, - schema=None, - resolve_synonyms=False, - dblink="", - **kw, - ): - - info_cache = kw.get("info_cache") - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + def get_indexes(self, connection, table_name, schema=None, **kw): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_indexes( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) - indexes = [] - - params = {"table_name": table_name} - text = ( - "SELECT a.index_name, a.column_name, " - "\nb.index_type, b.uniqueness, b.compression, b.prefix_length " - "\nFROM ALL_IND_COLUMNS%(dblink)s a, " - "\nALL_INDEXES%(dblink)s b " - "\nWHERE " - "\na.index_name = b.index_name " - "\nAND a.table_owner = b.table_owner " - "\nAND a.table_name = b.table_name " - "\nAND a.table_name = CAST(:table_name AS VARCHAR(128))" + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _index_query(self, owner): + return ( + select( + dictionary.all_ind_columns.c.table_name, + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_name, + dictionary.all_indexes.c.index_type, + dictionary.all_indexes.c.uniqueness, + dictionary.all_indexes.c.compression, + dictionary.all_indexes.c.prefix_length, + ) + .select_from(dictionary.all_ind_columns) + .join( + dictionary.all_indexes, + sql.and_( + dictionary.all_ind_columns.c.index_name + == dictionary.all_indexes.c.index_name, + dictionary.all_ind_columns.c.table_owner + == dictionary.all_indexes.c.table_owner, + # NOTE: this condition on table_name is not required + # but it improves the query performance noticeably + dictionary.all_ind_columns.c.table_name + == dictionary.all_indexes.c.table_name, + ), + ) + .where( + dictionary.all_ind_columns.c.table_owner == owner, + dictionary.all_ind_columns.c.table_name.in_( + bindparam("all_objects") + ), + ) + .order_by( + dictionary.all_ind_columns.c.index_name, + dictionary.all_ind_columns.c.column_position, + ) ) - if schema is not None: - params["schema"] = schema - text += "AND a.table_owner = :schema " + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_indexes_rows(self, connection, schema, dblink, all_objects, **kw): + owner = self.denormalize_name(schema or self.default_schema_name) - text += "ORDER BY a.index_name, a.column_position" + query = self._index_query(owner) - text = text % {"dblink": dblink} + pks = { + row_dict["constraint_name"] + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ) + if row_dict["constraint_type"] == "P" + } - q = sql.text(text) - rp = connection.execute(q, params) - indexes = [] - last_index_name = None - pk_constraint = self.get_pk_constraint( + result = self._run_batches( connection, - table_name, - schema, - resolve_synonyms=resolve_synonyms, - dblink=dblink, - info_cache=kw.get("info_cache"), + query, + dblink, + returns_long=False, + mappings=True, + all_objects=all_objects, ) - uniqueness = dict(NONUNIQUE=False, UNIQUE=True) - enabled = dict(DISABLED=False, ENABLED=True) + return [ + row_dict + for row_dict in result + if row_dict["index_name"] 
not in pks + ] - oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) + @_handle_synonyms_decorator + def get_multi_indexes( + self, + connection, + *, + schema, + filter_names, + scope, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) - index = None - for rset in rp: - index_name_normalized = self.normalize_name(rset.index_name) + uniqueness = {"NONUNIQUE": False, "UNIQUE": True} + enabled = {"DISABLED": False, "ENABLED": True} + is_bitmap = {"BITMAP", "FUNCTION-BASED BITMAP"} - # skip primary key index. This is refined as of - # [ticket:5421]. Note that ALL_INDEXES.GENERATED will by "Y" - # if the name of this index was generated by Oracle, however - # if a named primary key constraint was created then this flag - # is false. - if ( - pk_constraint - and index_name_normalized == pk_constraint["name"] - ): - continue + oracle_sys_col = re.compile(r"SYS_NC\d+\$", re.IGNORECASE) - if rset.index_name != last_index_name: - index = dict( - name=index_name_normalized, - column_names=[], - dialect_options={}, - ) - indexes.append(index) - index["unique"] = uniqueness.get(rset.uniqueness, False) + indexes = defaultdict(dict) + + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ): + index_name = self.normalize_name(row_dict["index_name"]) + table_name = self.normalize_name(row_dict["table_name"]) + table_indexes = indexes[(schema, table_name)] + + if index_name not in table_indexes: + table_indexes[index_name] = index_dict = { + "name": index_name, + "column_names": [], + "dialect_options": {}, + "unique": uniqueness.get(row_dict["uniqueness"], False), + } + do = index_dict["dialect_options"] + if row_dict["index_type"] in is_bitmap: + do["oracle_bitmap"] = True + if enabled.get(row_dict["compression"], False): + do["oracle_compress"] = row_dict["prefix_length"] - if rset.index_type in ("BITMAP", "FUNCTION-BASED BITMAP"): - index["dialect_options"]["oracle_bitmap"] = True - if enabled.get(rset.compression, False): - index["dialect_options"][ - "oracle_compress" - ] = rset.prefix_length + else: + index_dict = table_indexes[index_name] # filter out Oracle SYS_NC names. could also do an outer join - # to the all_tab_columns table and check for real col names there. - if not oracle_sys_col.match(rset.column_name): - index["column_names"].append( - self.normalize_name(rset.column_name) + # to the all_tab_columns table and check for real col names + # there. 
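+            # e.g. a function-based index adds a hidden virtual column whose
+            # name matches SYS_NC<digits>$ (such as "SYS_NC00011$", an
+            # illustrative value); the check below skips these
+            # system-generated names so that only real column names are
+            # reported.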
+ if not oracle_sys_col.match(row_dict["column_name"]): + index_dict["column_names"].append( + self.normalize_name(row_dict["column_name"]) ) - last_index_name = rset.index_name - return indexes + default = ReflectionDefaults.indexes - @reflection.cache - def _get_constraint_data( - self, connection, table_name, schema=None, dblink="", **kw - ): - - params = {"table_name": table_name} - - text = ( - "SELECT" - "\nac.constraint_name," # 0 - "\nac.constraint_type," # 1 - "\nloc.column_name AS local_column," # 2 - "\nrem.table_name AS remote_table," # 3 - "\nrem.column_name AS remote_column," # 4 - "\nrem.owner AS remote_owner," # 5 - "\nloc.position as loc_pos," # 6 - "\nrem.position as rem_pos," # 7 - "\nac.search_condition," # 8 - "\nac.delete_rule" # 9 - "\nFROM all_constraints%(dblink)s ac," - "\nall_cons_columns%(dblink)s loc," - "\nall_cons_columns%(dblink)s rem" - "\nWHERE ac.table_name = CAST(:table_name AS VARCHAR2(128))" - "\nAND ac.constraint_type IN ('R','P', 'U', 'C')" - ) - - if schema is not None: - params["owner"] = schema - text += "\nAND ac.owner = CAST(:owner AS VARCHAR2(128))" - - text += ( - "\nAND ac.owner = loc.owner" - "\nAND ac.constraint_name = loc.constraint_name" - "\nAND ac.r_owner = rem.owner(+)" - "\nAND ac.r_constraint_name = rem.constraint_name(+)" - "\nAND (rem.position IS NULL or loc.position=rem.position)" - "\nORDER BY ac.constraint_name, loc.position" + return ( + (key, list(indexes[key].values()) if key in indexes else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) ) - text = text % {"dblink": dblink} - rp = connection.execute(sql.text(text), params) - constraint_data = rp.fetchall() - return constraint_data - @reflection.cache def get_pk_constraint(self, connection, table_name, schema=None, **kw): - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_pk_constraint( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) - pkeys = [] - constraint_name = None - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + return self._value_or_raise(data, table_name, schema) + + @lru_cache() + def _constraint_query(self, owner): + local = dictionary.all_cons_columns.alias("local") + remote = dictionary.all_cons_columns.alias("remote") + return ( + select( + dictionary.all_constraints.c.table_name, + dictionary.all_constraints.c.constraint_type, + dictionary.all_constraints.c.constraint_name, + local.c.column_name.label("local_column"), + remote.c.table_name.label("remote_table"), + remote.c.column_name.label("remote_column"), + remote.c.owner.label("remote_owner"), + dictionary.all_constraints.c.search_condition, + dictionary.all_constraints.c.delete_rule, + ) + .select_from(dictionary.all_constraints) + .join( + local, + and_( + local.c.owner == dictionary.all_constraints.c.owner, + dictionary.all_constraints.c.constraint_name + == local.c.constraint_name, + ), + ) + .outerjoin( + remote, + and_( + dictionary.all_constraints.c.r_owner == remote.c.owner, + 
dictionary.all_constraints.c.r_constraint_name + == remote.c.constraint_name, + or_( + remote.c.position.is_(sql.null()), + local.c.position == remote.c.position, + ), + ), + ) + .where( + dictionary.all_constraints.c.owner == owner, + dictionary.all_constraints.c.table_name.in_( + bindparam("all_objects") + ), + dictionary.all_constraints.c.constraint_type.in_( + ("R", "P", "U", "C") + ), + ) + .order_by( + dictionary.all_constraints.c.constraint_name, local.c.position + ) ) - for row in constraint_data: - ( - cons_name, - cons_type, - local_column, - remote_table, - remote_column, - remote_owner, - ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) - if cons_type == "P": - if constraint_name is None: - constraint_name = self.normalize_name(cons_name) - pkeys.append(local_column) - return {"constrained_columns": pkeys, "name": constraint_name} + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("dblink", InternalTraversal.dp_string), + ("all_objects", InternalTraversal.dp_string_list), + ) + def _get_all_constraint_rows( + self, connection, schema, dblink, all_objects, **kw + ): + owner = self.denormalize_name(schema or self.default_schema_name) + query = self._constraint_query(owner) - @reflection.cache - def get_foreign_keys(self, connection, table_name, schema=None, **kw): + # since the result is cached a list must be created + values = list( + self._run_batches( + connection, + query, + dblink, + returns_long=False, + mappings=True, + all_objects=all_objects, + ) + ) + return values + + @_handle_synonyms_decorator + def get_multi_pk_constraint( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw + ) - kw arguments can be: + primary_keys = defaultdict(dict) + default = ReflectionDefaults.pk_constraint - oracle_resolve_synonyms + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "P": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + column_name = self.normalize_name(row_dict["local_column"]) + + table_pk = primary_keys[(schema, table_name)] + if not table_pk: + table_pk["name"] = constraint_name + table_pk["constrained_columns"] = [column_name] + else: + table_pk["constrained_columns"].append(column_name) - dblink + return ( + (key, primary_keys[key] if key in primary_keys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + @reflection.cache + def get_foreign_keys( + self, + connection, + table_name, + schema=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms """ - requested_schema = schema # to check later on - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + data = self.get_multi_foreign_keys( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return 
self._value_or_raise(data, table_name, schema) - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + @_handle_synonyms_decorator + def get_multi_foreign_keys( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw ) - def fkey_rec(): - return { - "name": None, - "constrained_columns": [], - "referred_schema": None, - "referred_table": None, - "referred_columns": [], - "options": {}, - } + resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - fkeys = util.defaultdict(fkey_rec) + owner = self.denormalize_name(schema or self.default_schema_name) - for row in constraint_data: - ( - cons_name, - cons_type, - local_column, - remote_table, - remote_column, - remote_owner, - ) = row[0:2] + tuple([self.normalize_name(x) for x in row[2:6]]) - - cons_name = self.normalize_name(cons_name) - - if cons_type == "R": - if remote_table is None: - # ticket 363 - util.warn( - ( - "Got 'None' querying 'table_name' from " - "all_cons_columns%(dblink)s - does the user have " - "proper rights to the table?" - ) - % {"dblink": dblink} - ) - continue + all_remote_owners = set() + fkeys = defaultdict(dict) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "R": + continue + + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + table_fkey = fkeys[(schema, table_name)] + + assert constraint_name is not None - rec = fkeys[cons_name] - rec["name"] = cons_name - local_cols, remote_cols = ( - rec["constrained_columns"], - rec["referred_columns"], + local_column = self.normalize_name(row_dict["local_column"]) + remote_table = self.normalize_name(row_dict["remote_table"]) + remote_column = self.normalize_name(row_dict["remote_column"]) + remote_owner_orig = row_dict["remote_owner"] + remote_owner = self.normalize_name(remote_owner_orig) + if remote_owner_orig is not None: + all_remote_owners.add(remote_owner_orig) + + if remote_table is None: + # ticket 363 + if dblink and not dblink.startswith("@"): + dblink = f"@{dblink}" + util.warn( + "Got 'None' querying 'table_name' from " + f"all_cons_columns{dblink or ''} - does the user have " + "proper rights to the table?" 
) + continue - if not rec["referred_table"]: - if resolve_synonyms: - ( - ref_remote_name, - ref_remote_owner, - ref_dblink, - ref_synonym, - ) = self._resolve_synonym( - connection, - desired_owner=self.denormalize_name(remote_owner), - desired_table=self.denormalize_name(remote_table), - ) - if ref_synonym: - remote_table = self.normalize_name(ref_synonym) - remote_owner = self.normalize_name( - ref_remote_owner - ) + if constraint_name not in table_fkey: + table_fkey[constraint_name] = fkey = { + "name": constraint_name, + "constrained_columns": [], + "referred_schema": None, + "referred_table": remote_table, + "referred_columns": [], + "options": {}, + } - rec["referred_table"] = remote_table + if resolve_synonyms: + # will be removed below + fkey["_ref_schema"] = remote_owner - if ( - requested_schema is not None - or self.denormalize_name(remote_owner) != schema - ): - rec["referred_schema"] = remote_owner + if schema is not None or remote_owner_orig != owner: + fkey["referred_schema"] = remote_owner + + delete_rule = row_dict["delete_rule"] + if delete_rule != "NO ACTION": + fkey["options"]["ondelete"] = delete_rule + + else: + fkey = table_fkey[constraint_name] + + fkey["constrained_columns"].append(local_column) + fkey["referred_columns"].append(remote_column) + + if resolve_synonyms and all_remote_owners: + query = select( + dictionary.all_synonyms.c.owner, + dictionary.all_synonyms.c.table_name, + dictionary.all_synonyms.c.table_owner, + dictionary.all_synonyms.c.synonym_name, + ).where(dictionary.all_synonyms.c.owner.in_(all_remote_owners)) + + result = self._execute_reflection( + connection, query, dblink, returns_long=False + ).mappings() - if row[9] != "NO ACTION": - rec["options"]["ondelete"] = row[9] + remote_owners_lut = {} + for row in result: + synonym_owner = self.normalize_name(row["owner"]) + table_name = self.normalize_name(row["table_name"]) - local_cols.append(local_column) - remote_cols.append(remote_column) + remote_owners_lut[(synonym_owner, table_name)] = ( + row["table_owner"], + row["synonym_name"], + ) + + empty = (None, None) + for table_fkeys in fkeys.values(): + for table_fkey in table_fkeys.values(): + key = ( + table_fkey.pop("_ref_schema"), + table_fkey["referred_table"], + ) + remote_owner, syn_name = remote_owners_lut.get(key, empty) + if syn_name: + sn = self.normalize_name(syn_name) + table_fkey["referred_table"] = sn + if schema is not None or remote_owner != owner: + ro = self.normalize_name(remote_owner) + table_fkey["referred_schema"] = ro + else: + table_fkey["referred_schema"] = None + default = ReflectionDefaults.foreign_keys - return list(fkeys.values()) + return ( + (key, list(fkeys[key].values()) if key in fkeys else default()) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) @reflection.cache def get_unique_constraints( self, connection, table_name, schema=None, **kw ): - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_unique_constraints( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, 
schema) - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + @_handle_synonyms_decorator + def get_multi_unique_constraints( + self, + connection, + *, + scope, + schema, + filter_names, + kind, + dblink=None, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw ) - unique_keys = filter(lambda x: x[1] == "U", constraint_data) - uniques_group = groupby(unique_keys, lambda x: x[0]) + unique_cons = defaultdict(dict) index_names = { - ix["name"] - for ix in self.get_indexes(connection, table_name, schema=schema) + row_dict["index_name"] + for row_dict in self._get_indexes_rows( + connection, schema, dblink, all_objects, **kw + ) } - return [ - { - "name": name, - "column_names": cols, - "duplicates_index": name if name in index_names else None, - } - for name, cols in [ - [ - self.normalize_name(i[0]), - [self.normalize_name(x[2]) for x in i[1]], - ] - for i in uniques_group - ] - ] + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "U": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name_orig = row_dict["constraint_name"] + constraint_name = self.normalize_name(constraint_name_orig) + column_name = self.normalize_name(row_dict["local_column"]) + table_uc = unique_cons[(schema, table_name)] + + assert constraint_name is not None + + if constraint_name not in table_uc: + table_uc[constraint_name] = uc = { + "name": constraint_name, + "column_names": [], + "duplicates_index": constraint_name + if constraint_name_orig in index_names + else None, + } + else: + uc = table_uc[constraint_name] + + uc["column_names"].append(column_name) + + default = ReflectionDefaults.unique_constraints + + return ( + ( + key, + list(unique_cons[key].values()) + if key in unique_cons + else default(), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) @reflection.cache def get_view_definition( @@ -2553,65 +2997,129 @@ class OracleDialect(default.DefaultDialect): connection, view_name, schema=None, - resolve_synonyms=False, - dblink="", + dblink=None, **kw, ): - info_cache = kw.get("info_cache") - (view_name, schema, dblink, synonym) = self._prepare_reflection_args( - connection, - view_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + if kw.get("oracle_resolve_synonyms", False): + synonyms = self._get_synonyms( + connection, schema, filter_names=[view_name], dblink=dblink + ) + if synonyms: + assert len(synonyms) == 1 + row_dict = synonyms[0] + dblink = self.normalize_name(row_dict["db_link"]) + schema = row_dict["table_owner"] + view_name = row_dict["table_name"] + + name = self.denormalize_name(view_name) + owner = self.denormalize_name(schema or self.default_schema_name) + query = ( + select(dictionary.all_views.c.text) + .where( + dictionary.all_views.c.view_name == name, + dictionary.all_views.c.owner == owner, + ) + .union_all( + select(dictionary.all_mviews.c.query).where( + dictionary.all_mviews.c.mview_name == name, + dictionary.all_mviews.c.owner == owner, + ) + ) ) - params = {"view_name": view_name} - text = 
"SELECT text FROM all_views WHERE view_name=:view_name" - - if schema is not None: - text += " AND owner = :schema" - params["schema"] = schema - - rp = connection.execute(sql.text(text), params).scalar() - if rp: - return rp + rp = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalar() + if rp is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) else: - return None + return rp @reflection.cache def get_check_constraints( self, connection, table_name, schema=None, include_all=False, **kw ): - resolve_synonyms = kw.get("oracle_resolve_synonyms", False) - dblink = kw.get("dblink", "") - info_cache = kw.get("info_cache") - - (table_name, schema, dblink, synonym) = self._prepare_reflection_args( + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + data = self.get_multi_check_constraints( connection, - table_name, - schema, - resolve_synonyms, - dblink, - info_cache=info_cache, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + include_all=include_all, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - constraint_data = self._get_constraint_data( - connection, - table_name, - schema, - dblink, - info_cache=kw.get("info_cache"), + @_handle_synonyms_decorator + def get_multi_check_constraints( + self, + connection, + *, + schema, + filter_names, + dblink=None, + scope, + kind, + include_all=False, + **kw, + ): + """Supported kw arguments are: ``dblink`` to reflect via a db link; + ``oracle_resolve_synonyms`` to resolve names to synonyms + """ + all_objects = self._get_all_objects( + connection, schema, scope, kind, filter_names, dblink, **kw ) - check_constraints = filter(lambda x: x[1] == "C", constraint_data) + not_null = re.compile(r"..+?. IS NOT NULL$") - return [ - {"name": self.normalize_name(cons[0]), "sqltext": cons[8]} - for cons in check_constraints - if include_all or not re.match(r"..+?. IS NOT NULL$", cons[8]) - ] + check_constraints = defaultdict(list) + + for row_dict in self._get_all_constraint_rows( + connection, schema, dblink, all_objects, **kw + ): + if row_dict["constraint_type"] != "C": + continue + table_name = self.normalize_name(row_dict["table_name"]) + constraint_name = self.normalize_name(row_dict["constraint_name"]) + search_condition = row_dict["search_condition"] + + table_checks = check_constraints[(schema, table_name)] + if constraint_name is not None and ( + include_all or not not_null.match(search_condition) + ): + table_checks.append( + {"name": constraint_name, "sqltext": search_condition} + ) + + default = ReflectionDefaults.check_constraints + + return ( + ( + key, + check_constraints[key] + if key in check_constraints + else default(), + ) + for key in ( + (schema, self.normalize_name(obj_name)) + for obj_name in all_objects + ) + ) + + def _list_dblinks(self, connection, dblink=None): + query = select(dictionary.all_db_links.c.db_link) + links = self._execute_reflection( + connection, query, dblink, returns_long=False + ).scalars() + return [self.normalize_name(link) for link in links] class _OuterJoinColumn(sql.ClauseElement): diff --git a/lib/sqlalchemy/dialects/oracle/cx_oracle.py b/lib/sqlalchemy/dialects/oracle/cx_oracle.py index 25e93632c..d2ee0a96e 100644 --- a/lib/sqlalchemy/dialects/oracle/cx_oracle.py +++ b/lib/sqlalchemy/dialects/oracle/cx_oracle.py @@ -431,6 +431,7 @@ from . 
import base as oracle from .base import OracleCompiler from .base import OracleDialect from .base import OracleExecutionContext +from .types import _OracleDateLiteralRender from ... import exc from ... import util from ...engine import cursor as _cursor @@ -573,7 +574,7 @@ class _CXOracleDate(oracle._OracleDate): return process -class _CXOracleTIMESTAMP(oracle._OracleDateLiteralRender, sqltypes.TIMESTAMP): +class _CXOracleTIMESTAMP(_OracleDateLiteralRender, sqltypes.TIMESTAMP): def literal_processor(self, dialect): return self._literal_processor_datetime(dialect) @@ -812,6 +813,7 @@ class OracleExecutionContext_cx_oracle(OracleExecutionContext): return None def pre_exec(self): + super().pre_exec() if not getattr(self.compiled, "_oracle_cx_sql_compiler", False): return diff --git a/lib/sqlalchemy/dialects/oracle/dictionary.py b/lib/sqlalchemy/dialects/oracle/dictionary.py new file mode 100644 index 000000000..ac7a350da --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/dictionary.py @@ -0,0 +1,495 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .types import DATE +from .types import LONG +from .types import NUMBER +from .types import RAW +from .types import VARCHAR2 +from ... import Column +from ... import MetaData +from ... import Table +from ... import table +from ...sql.sqltypes import CHAR + +# constants +DB_LINK_PLACEHOLDER = "__$sa_dblink$__" +# tables +dual = table("dual") +dictionary_meta = MetaData() + +# NOTE: all the dictionary_meta are aliases because oracle does not like +# using the full table@dblink for every column in query, and complains with +# ORA-00960: ambiguous column naming in select list +all_tables = Table( + "all_tables" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("tablespace_name", VARCHAR2(30)), + Column("cluster_name", VARCHAR2(128)), + Column("iot_name", VARCHAR2(128)), + Column("status", VARCHAR2(8)), + Column("pct_free", NUMBER), + Column("pct_used", NUMBER), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("logging", VARCHAR2(3)), + Column("backed_up", VARCHAR2(1)), + Column("num_rows", NUMBER), + Column("blocks", NUMBER), + Column("empty_blocks", NUMBER), + Column("avg_space", NUMBER), + Column("chain_cnt", NUMBER), + Column("avg_row_len", NUMBER), + Column("avg_space_freelist_blocks", NUMBER), + Column("num_freelist_blocks", NUMBER), + Column("degree", VARCHAR2(10)), + Column("instances", VARCHAR2(10)), + Column("cache", VARCHAR2(5)), + Column("table_lock", VARCHAR2(8)), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("partitioned", VARCHAR2(3)), + Column("iot_type", VARCHAR2(12)), + Column("temporary", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("nested", VARCHAR2(3)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("row_movement", VARCHAR2(8)), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + 
Column("skip_corrupt", VARCHAR2(8)), + Column("monitoring", VARCHAR2(3)), + Column("cluster_owner", VARCHAR2(128)), + Column("dependencies", VARCHAR2(8)), + Column("compression", VARCHAR2(8)), + Column("compress_for", VARCHAR2(30)), + Column("dropped", VARCHAR2(3)), + Column("read_only", VARCHAR2(3)), + Column("segment_created", VARCHAR2(3)), + Column("result_cache", VARCHAR2(7)), + Column("clustering", VARCHAR2(3)), + Column("activity_tracking", VARCHAR2(23)), + Column("dml_timestamp", VARCHAR2(25)), + Column("has_identity", VARCHAR2(3)), + Column("container_data", VARCHAR2(3)), + Column("inmemory", VARCHAR2(8)), + Column("inmemory_priority", VARCHAR2(8)), + Column("inmemory_distribute", VARCHAR2(15)), + Column("inmemory_compression", VARCHAR2(17)), + Column("inmemory_duplicate", VARCHAR2(13)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("externally_sharded", VARCHAR2(1)), + Column("externally_duplicated", VARCHAR2(1)), + Column("external", VARCHAR2(3)), + Column("hybrid", VARCHAR2(3)), + Column("cellmemory", VARCHAR2(24)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("inmemory_service", VARCHAR2(12)), + Column("inmemory_service_name", VARCHAR2(1000)), + Column("container_map_object", VARCHAR2(3)), + Column("memoptimize_read", VARCHAR2(8)), + Column("memoptimize_write", VARCHAR2(8)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("data_link_dml_enabled", VARCHAR2(3)), + Column("logical_replication", VARCHAR2(8)), +).alias("a_tables") + +all_views = Table( + "all_views" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("view_name", VARCHAR2(128), nullable=False), + Column("text_length", NUMBER), + Column("text", LONG), + Column("text_vc", VARCHAR2(4000)), + Column("type_text_length", NUMBER), + Column("type_text", VARCHAR2(4000)), + Column("oid_text_length", NUMBER), + Column("oid_text", VARCHAR2(4000)), + Column("view_type_owner", VARCHAR2(128)), + Column("view_type", VARCHAR2(128)), + Column("superview_name", VARCHAR2(128)), + Column("editioning_view", VARCHAR2(1)), + Column("read_only", VARCHAR2(1)), + Column("container_data", VARCHAR2(1)), + Column("bequeath", VARCHAR2(12)), + Column("origin_con_id", VARCHAR2(256)), + Column("default_collation", VARCHAR2(100)), + Column("containers_default", VARCHAR2(3)), + Column("container_map", VARCHAR2(3)), + Column("extended_data_link", VARCHAR2(3)), + Column("extended_data_link_map", VARCHAR2(3)), + Column("has_sensitive_column", VARCHAR2(3)), + Column("admit_null", VARCHAR2(3)), + Column("pdb_local_only", VARCHAR2(3)), +).alias("a_views") + +all_sequences = Table( + "all_sequences" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("sequence_owner", VARCHAR2(128), nullable=False), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("min_value", NUMBER), + Column("max_value", NUMBER), + Column("increment_by", NUMBER, nullable=False), + Column("cycle_flag", VARCHAR2(1)), + Column("order_flag", VARCHAR2(1)), + Column("cache_size", NUMBER, nullable=False), + Column("last_number", NUMBER, nullable=False), + Column("scale_flag", VARCHAR2(1)), + Column("extend_flag", VARCHAR2(1)), + Column("sharded_flag", VARCHAR2(1)), + Column("session_flag", VARCHAR2(1)), + Column("keep_value", VARCHAR2(1)), +).alias("a_sequences") + 
+all_users = Table( + "all_users" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("username", VARCHAR2(128), nullable=False), + Column("user_id", NUMBER, nullable=False), + Column("created", DATE, nullable=False), + Column("common", VARCHAR2(3)), + Column("oracle_maintained", VARCHAR2(1)), + Column("inherited", VARCHAR2(3)), + Column("default_collation", VARCHAR2(100)), + Column("implicit", VARCHAR2(3)), + Column("all_shard", VARCHAR2(3)), + Column("external_shard", VARCHAR2(3)), +).alias("a_users") + +all_mviews = Table( + "all_mviews" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("container_name", VARCHAR2(128), nullable=False), + Column("query", LONG), + Column("query_len", NUMBER(38)), + Column("updatable", VARCHAR2(1)), + Column("update_log", VARCHAR2(128)), + Column("master_rollback_seg", VARCHAR2(128)), + Column("master_link", VARCHAR2(128)), + Column("rewrite_enabled", VARCHAR2(1)), + Column("rewrite_capability", VARCHAR2(9)), + Column("refresh_mode", VARCHAR2(6)), + Column("refresh_method", VARCHAR2(8)), + Column("build_mode", VARCHAR2(9)), + Column("fast_refreshable", VARCHAR2(18)), + Column("last_refresh_type", VARCHAR2(8)), + Column("last_refresh_date", DATE), + Column("last_refresh_end_time", DATE), + Column("staleness", VARCHAR2(19)), + Column("after_fast_refresh", VARCHAR2(19)), + Column("unknown_prebuilt", VARCHAR2(1)), + Column("unknown_plsql_func", VARCHAR2(1)), + Column("unknown_external_table", VARCHAR2(1)), + Column("unknown_consider_fresh", VARCHAR2(1)), + Column("unknown_import", VARCHAR2(1)), + Column("unknown_trusted_fd", VARCHAR2(1)), + Column("compile_state", VARCHAR2(19)), + Column("use_no_index", VARCHAR2(1)), + Column("stale_since", DATE), + Column("num_pct_tables", NUMBER), + Column("num_fresh_pct_regions", NUMBER), + Column("num_stale_pct_regions", NUMBER), + Column("segment_created", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("default_collation", VARCHAR2(100)), + Column("on_query_computation", VARCHAR2(1)), + Column("auto", VARCHAR2(3)), +).alias("a_mviews") + +all_tab_identity_cols = Table( + "all_tab_identity_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("generation_type", VARCHAR2(10)), + Column("sequence_name", VARCHAR2(128), nullable=False), + Column("identity_options", VARCHAR2(298)), +).alias("a_tab_identity_cols") + +all_tab_cols = Table( + "all_tab_cols" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("data_type", VARCHAR2(128)), + Column("data_type_mod", VARCHAR2(3)), + Column("data_type_owner", VARCHAR2(128)), + Column("data_length", NUMBER, nullable=False), + Column("data_precision", NUMBER), + Column("data_scale", NUMBER), + Column("nullable", VARCHAR2(1)), + Column("column_id", NUMBER), + Column("default_length", NUMBER), + Column("data_default", LONG), + Column("num_distinct", NUMBER), + Column("low_value", RAW(1000)), + Column("high_value", RAW(1000)), + Column("density", NUMBER), + Column("num_nulls", NUMBER), + Column("num_buckets", NUMBER), + Column("last_analyzed", DATE), + 
Column("sample_size", NUMBER), + Column("character_set_name", VARCHAR2(44)), + Column("char_col_decl_length", NUMBER), + Column("global_stats", VARCHAR2(3)), + Column("user_stats", VARCHAR2(3)), + Column("avg_col_len", NUMBER), + Column("char_length", NUMBER), + Column("char_used", VARCHAR2(1)), + Column("v80_fmt_image", VARCHAR2(3)), + Column("data_upgraded", VARCHAR2(3)), + Column("hidden_column", VARCHAR2(3)), + Column("virtual_column", VARCHAR2(3)), + Column("segment_column_id", NUMBER), + Column("internal_column_id", NUMBER, nullable=False), + Column("histogram", VARCHAR2(15)), + Column("qualified_col_name", VARCHAR2(4000)), + Column("user_generated", VARCHAR2(3)), + Column("default_on_null", VARCHAR2(3)), + Column("identity_column", VARCHAR2(3)), + Column("evaluation_edition", VARCHAR2(128)), + Column("unusable_before", VARCHAR2(128)), + Column("unusable_beginning", VARCHAR2(128)), + Column("collation", VARCHAR2(100)), + Column("collated_column_id", NUMBER), +).alias("a_tab_cols") + +all_tab_comments = Table( + "all_tab_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", VARCHAR2(11)), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_tab_comments") + +all_col_comments = Table( + "all_col_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), + Column("origin_con_id", NUMBER), +).alias("a_col_comments") + +all_mview_comments = Table( + "all_mview_comments" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("mview_name", VARCHAR2(128), nullable=False), + Column("comments", VARCHAR2(4000)), +).alias("a_mview_comments") + +all_ind_columns = Table( + "all_ind_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("index_owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("column_position", NUMBER, nullable=False), + Column("column_length", NUMBER, nullable=False), + Column("char_length", NUMBER), + Column("descend", VARCHAR2(4)), + Column("collated_column_id", NUMBER), +).alias("a_ind_columns") + +all_indexes = Table( + "all_indexes" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("index_name", VARCHAR2(128), nullable=False), + Column("index_type", VARCHAR2(27)), + Column("table_owner", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("table_type", CHAR(11)), + Column("uniqueness", VARCHAR2(9)), + Column("compression", VARCHAR2(13)), + Column("prefix_length", NUMBER), + Column("tablespace_name", VARCHAR2(30)), + Column("ini_trans", NUMBER), + Column("max_trans", NUMBER), + Column("initial_extent", NUMBER), + Column("next_extent", NUMBER), + Column("min_extents", NUMBER), + Column("max_extents", NUMBER), + Column("pct_increase", NUMBER), + Column("pct_threshold", NUMBER), + Column("include_column", NUMBER), + Column("freelists", NUMBER), + Column("freelist_groups", NUMBER), + Column("pct_free", NUMBER), + Column("logging", VARCHAR2(3)), + Column("blevel", NUMBER), + 
Column("leaf_blocks", NUMBER), + Column("distinct_keys", NUMBER), + Column("avg_leaf_blocks_per_key", NUMBER), + Column("avg_data_blocks_per_key", NUMBER), + Column("clustering_factor", NUMBER), + Column("status", VARCHAR2(8)), + Column("num_rows", NUMBER), + Column("sample_size", NUMBER), + Column("last_analyzed", DATE), + Column("degree", VARCHAR2(40)), + Column("instances", VARCHAR2(40)), + Column("partitioned", VARCHAR2(3)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("buffer_pool", VARCHAR2(7)), + Column("flash_cache", VARCHAR2(7)), + Column("cell_flash_cache", VARCHAR2(7)), + Column("user_stats", VARCHAR2(3)), + Column("duration", VARCHAR2(15)), + Column("pct_direct_access", NUMBER), + Column("ityp_owner", VARCHAR2(128)), + Column("ityp_name", VARCHAR2(128)), + Column("parameters", VARCHAR2(1000)), + Column("global_stats", VARCHAR2(3)), + Column("domidx_status", VARCHAR2(12)), + Column("domidx_opstatus", VARCHAR2(6)), + Column("funcidx_status", VARCHAR2(8)), + Column("join_index", VARCHAR2(3)), + Column("iot_redundant_pkey_elim", VARCHAR2(3)), + Column("dropped", VARCHAR2(3)), + Column("visibility", VARCHAR2(9)), + Column("domidx_management", VARCHAR2(14)), + Column("segment_created", VARCHAR2(3)), + Column("orphaned_entries", VARCHAR2(3)), + Column("indexing", VARCHAR2(7)), + Column("auto", VARCHAR2(3)), +).alias("a_indexes") + +all_constraints = Table( + "all_constraints" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("constraint_name", VARCHAR2(128)), + Column("constraint_type", VARCHAR2(1)), + Column("table_name", VARCHAR2(128)), + Column("search_condition", LONG), + Column("search_condition_vc", VARCHAR2(4000)), + Column("r_owner", VARCHAR2(128)), + Column("r_constraint_name", VARCHAR2(128)), + Column("delete_rule", VARCHAR2(9)), + Column("status", VARCHAR2(8)), + Column("deferrable", VARCHAR2(14)), + Column("deferred", VARCHAR2(9)), + Column("validated", VARCHAR2(13)), + Column("generated", VARCHAR2(14)), + Column("bad", VARCHAR2(3)), + Column("rely", VARCHAR2(4)), + Column("last_change", DATE), + Column("index_owner", VARCHAR2(128)), + Column("index_name", VARCHAR2(128)), + Column("invalid", VARCHAR2(7)), + Column("view_related", VARCHAR2(14)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_constraints") + +all_cons_columns = Table( + "all_cons_columns" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("constraint_name", VARCHAR2(128), nullable=False), + Column("table_name", VARCHAR2(128), nullable=False), + Column("column_name", VARCHAR2(4000)), + Column("position", NUMBER), +).alias("a_cons_columns") + +# TODO figure out if it's still relevant, since there is no mention from here +# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html +# original note: +# using user_db_links here since all_db_links appears +# to have more restricted permissions. +# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm +# will need to hear from more users if we are doing +# the right thing here. 
See [ticket:2619] +all_db_links = Table( + "all_db_links" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("db_link", VARCHAR2(128), nullable=False), + Column("username", VARCHAR2(128)), + Column("host", VARCHAR2(2000)), + Column("created", DATE, nullable=False), + Column("hidden", VARCHAR2(3)), + Column("shard_internal", VARCHAR2(3)), + Column("valid", VARCHAR2(3)), + Column("intra_cdb", VARCHAR2(3)), +).alias("a_db_links") + +all_synonyms = Table( + "all_synonyms" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128)), + Column("synonym_name", VARCHAR2(128)), + Column("table_owner", VARCHAR2(128)), + Column("table_name", VARCHAR2(128)), + Column("db_link", VARCHAR2(128)), + Column("origin_con_id", VARCHAR2(256)), +).alias("a_synonyms") + +all_objects = Table( + "all_objects" + DB_LINK_PLACEHOLDER, + dictionary_meta, + Column("owner", VARCHAR2(128), nullable=False), + Column("object_name", VARCHAR2(128), nullable=False), + Column("subobject_name", VARCHAR2(128)), + Column("object_id", NUMBER, nullable=False), + Column("data_object_id", NUMBER), + Column("object_type", VARCHAR2(23)), + Column("created", DATE, nullable=False), + Column("last_ddl_time", DATE, nullable=False), + Column("timestamp", VARCHAR2(19)), + Column("status", VARCHAR2(7)), + Column("temporary", VARCHAR2(1)), + Column("generated", VARCHAR2(1)), + Column("secondary", VARCHAR2(1)), + Column("namespace", NUMBER, nullable=False), + Column("edition_name", VARCHAR2(128)), + Column("sharing", VARCHAR2(13)), + Column("editionable", VARCHAR2(1)), + Column("oracle_maintained", VARCHAR2(1)), + Column("application", VARCHAR2(1)), + Column("default_collation", VARCHAR2(100)), + Column("duplicated", VARCHAR2(1)), + Column("sharded", VARCHAR2(1)), + Column("created_appid", NUMBER), + Column("created_vsnid", NUMBER), + Column("modified_appid", NUMBER), + Column("modified_vsnid", NUMBER), +).alias("a_objects") diff --git a/lib/sqlalchemy/dialects/oracle/provision.py b/lib/sqlalchemy/dialects/oracle/provision.py index cba3b5be4..75b7a7aa9 100644 --- a/lib/sqlalchemy/dialects/oracle/provision.py +++ b/lib/sqlalchemy/dialects/oracle/provision.py @@ -2,9 +2,12 @@ from ... import create_engine from ... import exc +from ... 
import inspect from ...engine import url as sa_url from ...testing.provision import configure_follower from ...testing.provision import create_db +from ...testing.provision import drop_all_schema_objects_post_tables +from ...testing.provision import drop_all_schema_objects_pre_tables from ...testing.provision import drop_db from ...testing.provision import follower_url_from_main from ...testing.provision import log @@ -28,6 +31,10 @@ def _oracle_create_db(cfg, eng, ident): conn.exec_driver_sql("grant unlimited tablespace to %s" % ident) conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident) conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident) + # these are needed to create materialized views + conn.exec_driver_sql("grant create table to %s" % ident) + conn.exec_driver_sql("grant create table to %s_ts1" % ident) + conn.exec_driver_sql("grant create table to %s_ts2" % ident) @configure_follower.for_db("oracle") @@ -46,6 +53,30 @@ def _ora_drop_ignore(conn, dbname): return False +@drop_all_schema_objects_pre_tables.for_db("oracle") +def _ora_drop_all_schema_objects_pre_tables(cfg, eng): + _purge_recyclebin(eng) + _purge_recyclebin(eng, cfg.test_schema) + + +@drop_all_schema_objects_post_tables.for_db("oracle") +def _ora_drop_all_schema_objects_post_tables(cfg, eng): + + with eng.begin() as conn: + for syn in conn.dialect._get_synonyms(conn, None, None, None): + conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}") + + for syn in conn.dialect._get_synonyms( + conn, cfg.test_schema, None, None + ): + conn.exec_driver_sql( + f"drop synonym {cfg.test_schema}.{syn['synonym_name']}" + ) + + for tmp_table in inspect(conn).get_temp_table_names(): + conn.exec_driver_sql(f"drop table {tmp_table}") + + @drop_db.for_db("oracle") def _oracle_drop_db(cfg, eng, ident): with eng.begin() as conn: @@ -60,13 +91,10 @@ def _oracle_drop_db(cfg, eng, ident): @stop_test_class_outside_fixtures.for_db("oracle") -def stop_test_class_outside_fixtures(config, db, cls): +def _ora_stop_test_class_outside_fixtures(config, db, cls): try: - with db.begin() as conn: - # run magic command to get rid of identity sequences - # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501 - conn.exec_driver_sql("purge recyclebin") + _purge_recyclebin(db) except exc.DatabaseError as err: log.warning("purge recyclebin command failed: %s", err) @@ -85,6 +113,22 @@ def stop_test_class_outside_fixtures(config, db, cls): _all_conns.clear() +def _purge_recyclebin(eng, schema=None): + with eng.begin() as conn: + if schema is None: + # run magic command to get rid of identity sequences + # https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501 + conn.exec_driver_sql("purge recyclebin") + else: + # per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501 + for owner, object_name, type_ in conn.exec_driver_sql( + "select owner, object_name,type from " + "dba_recyclebin where owner=:schema and type='TABLE'", + {"schema": conn.dialect.denormalize_name(schema)}, + ).all(): + conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"') + + _all_conns = set() diff --git a/lib/sqlalchemy/dialects/oracle/types.py b/lib/sqlalchemy/dialects/oracle/types.py new file mode 100644 index 000000000..60a8ebcb5 --- /dev/null +++ b/lib/sqlalchemy/dialects/oracle/types.py @@ -0,0 +1,233 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS 
file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from ...sql import sqltypes +from ...types import NVARCHAR +from ...types import VARCHAR + + +class RAW(sqltypes._Binary): + __visit_name__ = "RAW" + + +OracleRaw = RAW + + +class NCLOB(sqltypes.Text): + __visit_name__ = "NCLOB" + + +class VARCHAR2(VARCHAR): + __visit_name__ = "VARCHAR2" + + +NVARCHAR2 = NVARCHAR + + +class NUMBER(sqltypes.Numeric, sqltypes.Integer): + __visit_name__ = "NUMBER" + + def __init__(self, precision=None, scale=None, asdecimal=None): + if asdecimal is None: + asdecimal = bool(scale and scale > 0) + + super(NUMBER, self).__init__( + precision=precision, scale=scale, asdecimal=asdecimal + ) + + def adapt(self, impltype): + ret = super(NUMBER, self).adapt(impltype) + # leave a hint for the DBAPI handler + ret._is_oracle_number = True + return ret + + @property + def _type_affinity(self): + if bool(self.scale and self.scale > 0): + return sqltypes.Numeric + else: + return sqltypes.Integer + + +class FLOAT(sqltypes.FLOAT): + """Oracle FLOAT. + + This is the same as :class:`_sqltypes.FLOAT` except that + an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision` + parameter is accepted, and + the :paramref:`_sqltypes.Float.precision` parameter is not accepted. + + Oracle FLOAT types indicate precision in terms of "binary precision", which + defaults to 126. For a REAL type, the value is 63. This parameter does not + cleanly map to a specific number of decimal places but is roughly + equivalent to the desired number of decimal places divided by 0.3103. + + .. versionadded:: 2.0 + + """ + + __visit_name__ = "FLOAT" + + def __init__( + self, + binary_precision=None, + asdecimal=False, + decimal_return_scale=None, + ): + r""" + Construct a FLOAT + + :param binary_precision: Oracle binary precision value to be rendered + in DDL. This may be approximated to the number of decimal characters + using the formula "decimal precision = 0.30103 * binary precision". + The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126. 
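+          For example, applying that formula to the default, a binary
+          precision of 126 corresponds to roughly 126 * 0.30103, i.e. about
+          38 decimal digits (illustrative arithmetic derived from the
+          formula above).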
+ + :param asdecimal: See :paramref:`_sqltypes.Float.asdecimal` + + :param decimal_return_scale: See + :paramref:`_sqltypes.Float.decimal_return_scale` + + """ + super().__init__( + asdecimal=asdecimal, decimal_return_scale=decimal_return_scale + ) + self.binary_precision = binary_precision + + +class BINARY_DOUBLE(sqltypes.Float): + __visit_name__ = "BINARY_DOUBLE" + + +class BINARY_FLOAT(sqltypes.Float): + __visit_name__ = "BINARY_FLOAT" + + +class BFILE(sqltypes.LargeBinary): + __visit_name__ = "BFILE" + + +class LONG(sqltypes.Text): + __visit_name__ = "LONG" + + +class _OracleDateLiteralRender: + def _literal_processor_datetime(self, dialect): + def process(value): + if value is not None: + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS.FF')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().replace("T", " ")}', """ + """'YYYY-MM-DD HH24:MI:SS')""" + ) + return value + + return process + + def _literal_processor_date(self, dialect): + def process(value): + if value is not None: + if getattr(value, "microsecond", None): + value = ( + f"""TO_TIMESTAMP""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + else: + value = ( + f"""TO_DATE""" + f"""('{value.isoformat().split("T")[0]}', """ + """'YYYY-MM-DD')""" + ) + return value + + return process + + +class DATE(_OracleDateLiteralRender, sqltypes.DateTime): + """Provide the oracle DATE type. + + This type has no special Python behavior, except that it subclasses + :class:`_types.DateTime`; this is to suit the fact that the Oracle + ``DATE`` type supports a time value. + + .. versionadded:: 0.9.4 + + """ + + __visit_name__ = "DATE" + + def literal_processor(self, dialect): + return self._literal_processor_datetime(dialect) + + def _compare_type_affinity(self, other): + return other._type_affinity in (sqltypes.DateTime, sqltypes.Date) + + +class _OracleDate(_OracleDateLiteralRender, sqltypes.Date): + def literal_processor(self, dialect): + return self._literal_processor_date(dialect) + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + __visit_name__ = "INTERVAL" + + def __init__(self, day_precision=None, second_precision=None): + """Construct an INTERVAL. + + Note that only DAY TO SECOND intervals are currently supported. + This is due to a lack of support for YEAR TO MONTH intervals + within available DBAPIs. + + :param day_precision: the day precision value. this is the number of + digits to store for the day field. Defaults to "2" + :param second_precision: the second precision value. this is the + number of digits to store for the fractional seconds field. + Defaults to "6". + + """ + self.day_precision = day_precision + self.second_precision = second_precision + + @classmethod + def _adapt_from_generic_interval(cls, interval): + return INTERVAL( + day_precision=interval.day_precision, + second_precision=interval.second_precision, + ) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval( + native=True, + second_precision=self.second_precision, + day_precision=self.day_precision, + ) + + +class ROWID(sqltypes.TypeEngine): + """Oracle ROWID type. + + When used in a cast() or similar, generates ROWID. 
+ + """ + + __visit_name__ = "ROWID" + + +class _OracleBoolean(sqltypes.Boolean): + def get_dbapi_type(self, dbapi): + return dbapi.NUMBER diff --git a/lib/sqlalchemy/dialects/postgresql/__init__.py b/lib/sqlalchemy/dialects/postgresql/__init__.py index c2472fb55..85bbf8c5b 100644 --- a/lib/sqlalchemy/dialects/postgresql/__init__.py +++ b/lib/sqlalchemy/dialects/postgresql/__init__.py @@ -19,31 +19,16 @@ from .array import Any from .array import ARRAY from .array import array from .base import BIGINT -from .base import BIT from .base import BOOLEAN -from .base import BYTEA from .base import CHAR -from .base import CIDR -from .base import CreateEnumType from .base import DATE from .base import DOUBLE_PRECISION -from .base import DropEnumType -from .base import ENUM from .base import FLOAT -from .base import INET from .base import INTEGER -from .base import INTERVAL -from .base import MACADDR -from .base import MONEY from .base import NUMERIC -from .base import OID from .base import REAL -from .base import REGCLASS from .base import SMALLINT from .base import TEXT -from .base import TIME -from .base import TIMESTAMP -from .base import TSVECTOR from .base import UUID from .base import VARCHAR from .dml import Insert @@ -61,7 +46,21 @@ from .ranges import INT8RANGE from .ranges import NUMRANGE from .ranges import TSRANGE from .ranges import TSTZRANGE -from ...util import compat +from .types import BIT +from .types import BYTEA +from .types import CIDR +from .types import CreateEnumType +from .types import DropEnumType +from .types import ENUM +from .types import INET +from .types import INTERVAL +from .types import MACADDR +from .types import MONEY +from .types import OID +from .types import REGCLASS +from .types import TIME +from .types import TIMESTAMP +from .types import TSVECTOR # Alias psycopg also as psycopg_async psycopg_async = type( diff --git a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py index e831f2ed9..8dcd36c6d 100644 --- a/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py +++ b/lib/sqlalchemy/dialects/postgresql/_psycopg_common.py @@ -1,3 +1,8 @@ +# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php # mypy: ignore-errors import decimal @@ -9,6 +14,9 @@ from .base import _INT_TYPES from .base import PGDialect from .base import PGExecutionContext from .hstore import HSTORE +from .pg_catalog import _SpaceVector +from .pg_catalog import INT2VECTOR +from .pg_catalog import OIDVECTOR from ... import exc from ... import types as sqltypes from ... 
import util @@ -66,6 +74,14 @@ class _PsycopgARRAY(PGARRAY): render_bind_cast = True +class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR): + pass + + +class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + class _PGExecutionContext_common_psycopg(PGExecutionContext): def create_server_side_cursor(self): # use server-side cursors: @@ -91,6 +107,8 @@ class _PGDialect_common_psycopg(PGDialect): sqltypes.Numeric: _PsycopgNumeric, HSTORE: _PsycopgHStore, sqltypes.ARRAY: _PsycopgARRAY, + INT2VECTOR: _PsycopgINT2VECTOR, + OIDVECTOR: _PsycopgOIDVECTOR, }, ) diff --git a/lib/sqlalchemy/dialects/postgresql/asyncpg.py b/lib/sqlalchemy/dialects/postgresql/asyncpg.py index 1ec787e1f..d6385a5d6 100644 --- a/lib/sqlalchemy/dialects/postgresql/asyncpg.py +++ b/lib/sqlalchemy/dialects/postgresql/asyncpg.py @@ -274,6 +274,10 @@ class AsyncpgOID(OID): render_bind_cast = True +class AsyncpgCHAR(sqltypes.CHAR): + render_bind_cast = True + + class PGExecutionContext_asyncpg(PGExecutionContext): def handle_dbapi_exception(self, e): if isinstance( @@ -823,6 +827,7 @@ class PGDialect_asyncpg(PGDialect): sqltypes.Enum: AsyncPgEnum, OID: AsyncpgOID, REGCLASS: AsyncpgREGCLASS, + sqltypes.CHAR: AsyncpgCHAR, }, ) is_async = True diff --git a/lib/sqlalchemy/dialects/postgresql/base.py b/lib/sqlalchemy/dialects/postgresql/base.py index 83e46151f..36de76e0d 100644 --- a/lib/sqlalchemy/dialects/postgresql/base.py +++ b/lib/sqlalchemy/dialects/postgresql/base.py @@ -11,7 +11,7 @@ r""" :name: PostgreSQL :full_support: 9.6, 10, 11, 12, 13, 14 :normal_support: 9.6+ - :best_effort: 8+ + :best_effort: 9+ .. _postgresql_sequences: @@ -1448,23 +1448,52 @@ E.g.:: from __future__ import annotations from collections import defaultdict -import datetime as dt +from functools import lru_cache import re -from typing import Any from . import array as _array from . import dml from . import hstore as _hstore from . import json as _json +from . import pg_catalog from . import ranges as _ranges +from .types import _DECIMAL_TYPES # noqa +from .types import _FLOAT_TYPES # noqa +from .types import _INT_TYPES # noqa +from .types import BIT +from .types import BYTEA +from .types import CIDR +from .types import CreateEnumType # noqa +from .types import DropEnumType # noqa +from .types import ENUM +from .types import INET +from .types import INTERVAL +from .types import MACADDR +from .types import MONEY +from .types import OID +from .types import PGBit # noqa +from .types import PGCidr # noqa +from .types import PGInet # noqa +from .types import PGInterval # noqa +from .types import PGMacAddr # noqa +from .types import PGUuid +from .types import REGCLASS +from .types import TIME +from .types import TIMESTAMP +from .types import TSVECTOR from ... import exc from ... import schema +from ... import select from ... import sql from ... 
import util from ...engine import characteristics from ...engine import default from ...engine import interfaces +from ...engine import ObjectKind +from ...engine import ObjectScope from ...engine import reflection +from ...engine.reflection import ReflectionDefaults +from ...sql import bindparam from ...sql import coercions from ...sql import compiler from ...sql import elements @@ -1472,7 +1501,7 @@ from ...sql import expression from ...sql import roles from ...sql import sqltypes from ...sql import util as sql_util -from ...sql.ddl import InvokeDDLBase +from ...sql.visitors import InternalTraversal from ...types import BIGINT from ...types import BOOLEAN from ...types import CHAR @@ -1596,469 +1625,6 @@ RESERVED_WORDS = set( ] ) -_DECIMAL_TYPES = (1231, 1700) -_FLOAT_TYPES = (700, 701, 1021, 1022) -_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) - - -class PGUuid(UUID): - render_bind_cast = True - render_literal_cast = True - - -class BYTEA(sqltypes.LargeBinary[bytes]): - __visit_name__ = "BYTEA" - - -class INET(sqltypes.TypeEngine[str]): - __visit_name__ = "INET" - - -PGInet = INET - - -class CIDR(sqltypes.TypeEngine[str]): - __visit_name__ = "CIDR" - - -PGCidr = CIDR - - -class MACADDR(sqltypes.TypeEngine[str]): - __visit_name__ = "MACADDR" - - -PGMacAddr = MACADDR - - -class MONEY(sqltypes.TypeEngine[str]): - - r"""Provide the PostgreSQL MONEY type. - - Depending on driver, result rows using this type may return a - string value which includes currency symbols. - - For this reason, it may be preferable to provide conversion to a - numerically-based currency datatype using :class:`_types.TypeDecorator`:: - - import re - import decimal - from sqlalchemy import TypeDecorator - - class NumericMoney(TypeDecorator): - impl = MONEY - - def process_result_value(self, value: Any, dialect: Any) -> None: - if value is not None: - # adjust this for the currency and numeric - m = re.match(r"\$([\d.]+)", value) - if m: - value = decimal.Decimal(m.group(1)) - return value - - Alternatively, the conversion may be applied as a CAST using - the :meth:`_types.TypeDecorator.column_expression` method as follows:: - - import decimal - from sqlalchemy import cast - from sqlalchemy import TypeDecorator - - class NumericMoney(TypeDecorator): - impl = MONEY - - def column_expression(self, column: Any): - return cast(column, Numeric()) - - .. versionadded:: 1.2 - - """ - - __visit_name__ = "MONEY" - - -class OID(sqltypes.TypeEngine[int]): - - """Provide the PostgreSQL OID type. - - .. versionadded:: 0.9.5 - - """ - - __visit_name__ = "OID" - - -class REGCLASS(sqltypes.TypeEngine[str]): - - """Provide the PostgreSQL REGCLASS type. - - .. versionadded:: 1.2.7 - - """ - - __visit_name__ = "REGCLASS" - - -class TIMESTAMP(sqltypes.TIMESTAMP): - def __init__(self, timezone=False, precision=None): - super(TIMESTAMP, self).__init__(timezone=timezone) - self.precision = precision - - -class TIME(sqltypes.TIME): - def __init__(self, timezone=False, precision=None): - super(TIME, self).__init__(timezone=timezone) - self.precision = precision - - -class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): - - """PostgreSQL INTERVAL type.""" - - __visit_name__ = "INTERVAL" - native = True - - def __init__(self, precision=None, fields=None): - """Construct an INTERVAL. - - :param precision: optional integer precision value - :param fields: string fields specifier. allows storage of fields - to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, - etc. - - .. 
versionadded:: 1.2 - - """ - self.precision = precision - self.fields = fields - - @classmethod - def adapt_emulated_to_native(cls, interval, **kw): - return INTERVAL(precision=interval.second_precision) - - @property - def _type_affinity(self): - return sqltypes.Interval - - def as_generic(self, allow_nulltype=False): - return sqltypes.Interval(native=True, second_precision=self.precision) - - @property - def python_type(self): - return dt.timedelta - - -PGInterval = INTERVAL - - -class BIT(sqltypes.TypeEngine[int]): - __visit_name__ = "BIT" - - def __init__(self, length=None, varying=False): - if not varying: - # BIT without VARYING defaults to length 1 - self.length = length or 1 - else: - # but BIT VARYING can be unlimited-length, so no default - self.length = length - self.varying = varying - - -PGBit = BIT - - -class TSVECTOR(sqltypes.TypeEngine[Any]): - - """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL - text search type TSVECTOR. - - It can be used to do full text queries on natural language - documents. - - .. versionadded:: 0.9.0 - - .. seealso:: - - :ref:`postgresql_match` - - """ - - __visit_name__ = "TSVECTOR" - - -class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): - - """PostgreSQL ENUM type. - - This is a subclass of :class:`_types.Enum` which includes - support for PG's ``CREATE TYPE`` and ``DROP TYPE``. - - When the builtin type :class:`_types.Enum` is used and the - :paramref:`.Enum.native_enum` flag is left at its default of - True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` - type as the implementation, so the special create/drop rules - will be used. - - The create/drop behavior of ENUM is necessarily intricate, due to the - awkward relationship the ENUM type has in relationship to the - parent table, in that it may be "owned" by just a single table, or - may be shared among many tables. - - When using :class:`_types.Enum` or :class:`_postgresql.ENUM` - in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted - corresponding to when the :meth:`_schema.Table.create` and - :meth:`_schema.Table.drop` - methods are called:: - - table = Table('sometable', metadata, - Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) - ) - - table.create(engine) # will emit CREATE ENUM and CREATE TABLE - table.drop(engine) # will emit DROP TABLE and DROP ENUM - - To use a common enumerated type between multiple tables, the best - practice is to declare the :class:`_types.Enum` or - :class:`_postgresql.ENUM` independently, and associate it with the - :class:`_schema.MetaData` object itself:: - - my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) - - t1 = Table('sometable_one', metadata, - Column('some_enum', myenum) - ) - - t2 = Table('sometable_two', metadata, - Column('some_enum', myenum) - ) - - When this pattern is used, care must still be taken at the level - of individual table creates. 
Emitting CREATE TABLE without also - specifying ``checkfirst=True`` will still cause issues:: - - t1.create(engine) # will fail: no such type 'myenum' - - If we specify ``checkfirst=True``, the individual table-level create - operation will check for the ``ENUM`` and create if not exists:: - - # will check if enum exists, and emit CREATE TYPE if not - t1.create(engine, checkfirst=True) - - When using a metadata-level ENUM type, the type will always be created - and dropped if either the metadata-wide create/drop is called:: - - metadata.create_all(engine) # will emit CREATE TYPE - metadata.drop_all(engine) # will emit DROP TYPE - - The type can also be created and dropped directly:: - - my_enum.create(engine) - my_enum.drop(engine) - - .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type - now behaves more strictly with regards to CREATE/DROP. A metadata-level - ENUM type will only be created and dropped at the metadata level, - not the table level, with the exception of - ``table.create(checkfirst=True)``. - The ``table.drop()`` call will now emit a DROP TYPE for a table-level - enumerated type. - - """ - - native_enum = True - - def __init__(self, *enums, **kw): - """Construct an :class:`_postgresql.ENUM`. - - Arguments are the same as that of - :class:`_types.Enum`, but also including - the following parameters. - - :param create_type: Defaults to True. - Indicates that ``CREATE TYPE`` should be - emitted, after optionally checking for the - presence of the type, when the parent - table is being created; and additionally - that ``DROP TYPE`` is called when the table - is dropped. When ``False``, no check - will be performed and no ``CREATE TYPE`` - or ``DROP TYPE`` is emitted, unless - :meth:`~.postgresql.ENUM.create` - or :meth:`~.postgresql.ENUM.drop` - are called directly. - Setting to ``False`` is helpful - when invoking a creation scheme to a SQL file - without access to the actual database - - the :meth:`~.postgresql.ENUM.create` and - :meth:`~.postgresql.ENUM.drop` methods can - be used to emit SQL to a target bind. - - """ - native_enum = kw.pop("native_enum", None) - if native_enum is False: - util.warn( - "the native_enum flag does not apply to the " - "sqlalchemy.dialects.postgresql.ENUM datatype; this type " - "always refers to ENUM. Use sqlalchemy.types.Enum for " - "non-native enum." - ) - self.create_type = kw.pop("create_type", True) - super(ENUM, self).__init__(*enums, **kw) - - @classmethod - def adapt_emulated_to_native(cls, impl, **kw): - """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain - :class:`.Enum`. - - """ - kw.setdefault("validate_strings", impl.validate_strings) - kw.setdefault("name", impl.name) - kw.setdefault("schema", impl.schema) - kw.setdefault("inherit_schema", impl.inherit_schema) - kw.setdefault("metadata", impl.metadata) - kw.setdefault("_create_events", False) - kw.setdefault("values_callable", impl.values_callable) - kw.setdefault("omit_aliases", impl._omit_aliases) - return cls(**kw) - - def create(self, bind=None, checkfirst=True): - """Emit ``CREATE TYPE`` for this - :class:`_postgresql.ENUM`. - - If the underlying dialect does not support - PostgreSQL CREATE TYPE, no action is taken. - - :param bind: a connectable :class:`_engine.Engine`, - :class:`_engine.Connection`, or similar object to emit - SQL. - :param checkfirst: if ``True``, a query against - the PG catalog will be first performed to see - if the type does not exist already before - creating. 
- - """ - if not bind.dialect.supports_native_enum: - return - - bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) - - def drop(self, bind=None, checkfirst=True): - """Emit ``DROP TYPE`` for this - :class:`_postgresql.ENUM`. - - If the underlying dialect does not support - PostgreSQL DROP TYPE, no action is taken. - - :param bind: a connectable :class:`_engine.Engine`, - :class:`_engine.Connection`, or similar object to emit - SQL. - :param checkfirst: if ``True``, a query against - the PG catalog will be first performed to see - if the type actually exists before dropping. - - """ - if not bind.dialect.supports_native_enum: - return - - bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) - - class EnumGenerator(InvokeDDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - - def _can_create_enum(self, enum): - if not self.checkfirst: - return True - - effective_schema = self.connection.schema_for_object(enum) - - return not self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) - - def visit_enum(self, enum): - if not self._can_create_enum(enum): - return - - self.connection.execute(CreateEnumType(enum)) - - class EnumDropper(InvokeDDLBase): - def __init__(self, dialect, connection, checkfirst=False, **kwargs): - super(ENUM.EnumDropper, self).__init__(connection, **kwargs) - self.checkfirst = checkfirst - - def _can_drop_enum(self, enum): - if not self.checkfirst: - return True - - effective_schema = self.connection.schema_for_object(enum) - - return self.connection.dialect.has_type( - self.connection, enum.name, schema=effective_schema - ) - - def visit_enum(self, enum): - if not self._can_drop_enum(enum): - return - - self.connection.execute(DropEnumType(enum)) - - def get_dbapi_type(self, dbapi): - """dont return dbapi.STRING for ENUM in PostgreSQL, since that's - a different type""" - - return None - - def _check_for_name_in_memos(self, checkfirst, kw): - """Look in the 'ddl runner' for 'memos', then - note our name in that collection. - - This to ensure a particular named enum is operated - upon only once within any kind of create/drop - sequence without relying upon "checkfirst". 
- - """ - if not self.create_type: - return True - if "_ddl_runner" in kw: - ddl_runner = kw["_ddl_runner"] - if "_pg_enums" in ddl_runner.memo: - pg_enums = ddl_runner.memo["_pg_enums"] - else: - pg_enums = ddl_runner.memo["_pg_enums"] = set() - present = (self.schema, self.name) in pg_enums - pg_enums.add((self.schema, self.name)) - return present - else: - return False - - def _on_table_create(self, target, bind, checkfirst=False, **kw): - if ( - checkfirst - or ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - ) - ) and not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_table_drop(self, target, bind, checkfirst=False, **kw): - if ( - not self.metadata - and not kw.get("_is_metadata_operation", False) - and not self._check_for_name_in_memos(checkfirst, kw) - ): - self.drop(bind=bind, checkfirst=checkfirst) - - def _on_metadata_create(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.create(bind=bind, checkfirst=checkfirst) - - def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): - if not self._check_for_name_in_memos(checkfirst, kw): - self.drop(bind=bind, checkfirst=checkfirst) - - colspecs = { sqltypes.ARRAY: _array.ARRAY, sqltypes.Interval: INTERVAL, @@ -2997,8 +2563,19 @@ class PGIdentifierPreparer(compiler.IdentifierPreparer): class PGInspector(reflection.Inspector): + dialect: PGDialect + def get_table_oid(self, table_name, schema=None): - """Return the OID for the given table name.""" + """Return the OID for the given table name. + + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + """ with self._operation_context() as conn: return self.dialect.get_table_oid( @@ -3023,9 +2600,10 @@ class PGInspector(reflection.Inspector): .. versionadded:: 1.0.0 """ - schema = schema or self.default_schema_name with self._operation_context() as conn: - return self.dialect._load_enums(conn, schema) + return self.dialect._load_enums( + conn, schema, info_cache=self.info_cache + ) def get_foreign_table_names(self, schema=None): """Return a list of FOREIGN TABLE names. @@ -3038,38 +2616,29 @@ class PGInspector(reflection.Inspector): .. versionadded:: 1.0.0 """ - schema = schema or self.default_schema_name with self._operation_context() as conn: - return self.dialect._get_foreign_table_names(conn, schema) - - def get_view_names(self, schema=None, include=("plain", "materialized")): - """Return all view names in `schema`. + return self.dialect._get_foreign_table_names( + conn, schema, info_cache=self.info_cache + ) - :param schema: Optional, retrieve names from a non-default schema. - For special quoting, use :class:`.quoted_name`. + def has_type(self, type_name, schema=None, **kw): + """Return if the database has the specified type in the provided + schema. - :param include: specify which types of views to return. Passed - as a string value (for a single type) or a tuple (for any number - of types). Defaults to ``('plain', 'materialized')``. + :param type_name: the type to check. + :param schema: schema name. If None, the default schema + (typically 'public') is used. May also be set to '*' to + check in all schemas. - .. versionadded:: 1.1 + .. 
versionadded:: 2.0 """ - with self._operation_context() as conn: - return self.dialect.get_view_names( - conn, schema, info_cache=self.info_cache, include=include + return self.dialect.has_type( + conn, type_name, schema, info_cache=self.info_cache ) -class CreateEnumType(schema._CreateDropBase): - __visit_name__ = "create_enum_type" - - -class DropEnumType(schema._CreateDropBase): - __visit_name__ = "drop_enum_type" - - class PGExecutionContext(default.DefaultExecutionContext): def fire_sequence(self, seq, type_): return self._execute_scalar( @@ -3262,35 +2831,14 @@ class PGDialect(default.DefaultDialect): def initialize(self, connection): super(PGDialect, self).initialize(connection) - if self.server_version_info <= (8, 2): - self.delete_returning = ( - self.update_returning - ) = self.insert_returning = False - - self.supports_native_enum = self.server_version_info >= (8, 3) - if not self.supports_native_enum: - self.colspecs = self.colspecs.copy() - # pop base Enum type - self.colspecs.pop(sqltypes.Enum, None) - # psycopg2, others may have placed ENUM here as well - self.colspecs.pop(ENUM, None) - # https://www.postgresql.org/docs/9.3/static/release-9-2.html#AEN116689 self.supports_smallserial = self.server_version_info >= (9, 2) - if self.server_version_info < (8, 2): - self._backslash_escapes = False - else: - # ensure this query is not emitted on server version < 8.2 - # as it will fail - std_string = connection.exec_driver_sql( - "show standard_conforming_strings" - ).scalar() - self._backslash_escapes = std_string == "off" - - self._supports_create_index_concurrently = ( - self.server_version_info >= (8, 2) - ) + std_string = connection.exec_driver_sql( + "show standard_conforming_strings" + ).scalar() + self._backslash_escapes = std_string == "off" + self._supports_drop_index_concurrently = self.server_version_info >= ( 9, 2, @@ -3370,122 +2918,100 @@ class PGDialect(default.DefaultDialect): self.do_commit(connection.connection) def do_recover_twophase(self, connection): - resultset = connection.execute( + return connection.scalars( sql.text("SELECT gid FROM pg_prepared_xacts") - ) - return [row[0] for row in resultset] + ).all() def _get_default_schema_name(self, connection): return connection.exec_driver_sql("select current_schema()").scalar() - def has_schema(self, connection, schema): - query = ( - "select nspname from pg_namespace " "where lower(nspname)=:schema" - ) - cursor = connection.execute( - sql.text(query).bindparams( - sql.bindparam( - "schema", - str(schema.lower()), - type_=sqltypes.Unicode, - ) - ) + @reflection.cache + def has_schema(self, connection, schema, **kw): + query = select(pg_catalog.pg_namespace.c.nspname).where( + pg_catalog.pg_namespace.c.nspname == schema ) + return bool(connection.scalar(query)) - return bool(cursor.first()) - - def has_table(self, connection, table_name, schema=None): - self._ensure_has_table_connection(connection) - # seems like case gets folded in pg_class... 
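
A short usage sketch for the presence checks wired up above; the connection
URL and the schema / type names are placeholders, not part of the patch::

    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = inspect(engine)

    # generic Inspector method, backed by the pg_namespace query above
    print(insp.has_schema("public"))

    # PostgreSQL-specific PGInspector method, backed by the dialect-level has_type()
    print(insp.has_type("myenum", schema="public"))
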
+ def _pg_class_filter_scope_schema( + self, query, schema, scope, pg_class_table=None + ): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + query = query.join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid == pg_class_table.c.relnamespace, + ) + if scope is ObjectScope.DEFAULT: + query = query.where(pg_class_table.c.relpersistence != "t") + elif scope is ObjectScope.TEMPORARY: + query = query.where(pg_class_table.c.relpersistence == "t") if schema is None: - cursor = connection.execute( - sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where " - "pg_catalog.pg_table_is_visible(c.oid) " - "and relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(table_name), - type_=sqltypes.Unicode, - ) - ) + query = query.where( + pg_catalog.pg_table_is_visible(pg_class_table.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", ) else: - cursor = connection.execute( - sql.text( - "select relname from pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where n.nspname=:schema and " - "relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(table_name), - type_=sqltypes.Unicode, - ), - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) - ) - return bool(cursor.first()) - - def has_sequence(self, connection, sequence_name, schema=None): - if schema is None: - schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema and relname=:name" - ).bindparams( - sql.bindparam( - "name", - str(sequence_name), - type_=sqltypes.Unicode, - ), - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + def _pg_class_relkind_condition(self, relkinds, pg_class_table=None): + if pg_class_table is None: + pg_class_table = pg_catalog.pg_class + # uses the any form instead of in otherwise postgresql complaings + # that 'IN could not convert type character to "char"' + return pg_class_table.c.relkind == sql.any_(_array.array(relkinds)) + + @lru_cache() + def _has_table_query(self, schema): + query = select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relname == bindparam("table_name"), + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), + ) + return self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY ) - return bool(cursor.first()) + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): + self._ensure_has_table_connection(connection) + query = self._has_table_query(schema) + return bool(connection.scalar(query, {"table_name": table_name})) - def has_type(self, connection, type_name, schema=None): - if schema is not None: - query = """ - SELECT EXISTS ( - SELECT * FROM pg_catalog.pg_type t, pg_catalog.pg_namespace n - WHERE t.typnamespace = n.oid - AND t.typname = :typname - AND n.nspname = :nspname - ) - """ - query = sql.text(query) - else: - query = """ - SELECT EXISTS ( - SELECT * FROM pg_catalog.pg_type t - WHERE t.typname = :typname - AND pg_type_is_visible(t.oid) - ) - """ - query = sql.text(query) - query = query.bindparams( - sql.bindparam("typname", str(type_name), type_=sqltypes.Unicode) + @reflection.cache + def has_sequence(self, connection, sequence_name, schema=None, **kw): + query = 
select(pg_catalog.pg_class.c.relname).where( + pg_catalog.pg_class.c.relkind == "S", + pg_catalog.pg_class.c.relname == sequence_name, ) - if schema is not None: - query = query.bindparams( - sql.bindparam("nspname", str(schema), type_=sqltypes.Unicode) + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + return bool(connection.scalar(query)) + + @reflection.cache + def has_type(self, connection, type_name, schema=None, **kw): + query = ( + select(pg_catalog.pg_type.c.typname) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, ) - cursor = connection.execute(query) - return bool(cursor.scalar()) + .where(pg_catalog.pg_type.c.typname == type_name) + ) + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + + return bool(connection.scalar(query)) def _get_server_version_info(self, connection): v = connection.exec_driver_sql("select pg_catalog.version()").scalar() @@ -3502,229 +3028,300 @@ class PGDialect(default.DefaultDialect): @reflection.cache def get_table_oid(self, connection, table_name, schema=None, **kw): - """Fetch the oid for schema.table_name. - - Several reflection methods require the table oid. The idea for using - this method is that it can be fetched one time and cached for - subsequent calls. - - """ - table_oid = None - if schema is not None: - schema_where_clause = "n.nspname = :schema" - else: - schema_where_clause = "pg_catalog.pg_table_is_visible(c.oid)" - query = ( - """ - SELECT c.oid - FROM pg_catalog.pg_class c - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace - WHERE (%s) - AND c.relname = :table_name AND c.relkind in - ('r', 'v', 'm', 'f', 'p') - """ - % schema_where_clause + """Fetch the oid for schema.table_name.""" + query = select(pg_catalog.pg_class.c.oid).where( + pg_catalog.pg_class.c.relname == table_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_ALL_TABLE_LIKE + ), ) - # Since we're binding to unicode, table_name and schema_name must be - # unicode. 
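
For orientation, the pg_class / pg_namespace lookups built from Core
constructs above correspond roughly to a hand-written query such as the
following; this is an illustrative sketch (URL and object names assumed), not
the exact SQL the dialect emits::

    from sqlalchemy import create_engine, text

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")

    with engine.connect() as conn:
        exists = conn.scalar(
            text(
                "SELECT 1 FROM pg_catalog.pg_class c "
                "JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace "
                "WHERE c.relname = :name AND n.nspname = :schema "
                "AND c.relkind IN ('r', 'p', 'f', 'v', 'm')"
            ),
            {"name": "some_table", "schema": "public"},
        )
        print(bool(exists))
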
- table_name = str(table_name) - if schema is not None: - schema = str(schema) - s = sql.text(query).bindparams(table_name=sqltypes.Unicode) - s = s.columns(oid=sqltypes.Integer) - if schema: - s = s.bindparams(sql.bindparam("schema", type_=sqltypes.Unicode)) - c = connection.execute(s, dict(table_name=table_name, schema=schema)) - table_oid = c.scalar() + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + table_oid = connection.scalar(query) if table_oid is None: - raise exc.NoSuchTableError(table_name) + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) return table_oid @reflection.cache def get_schema_names(self, connection, **kw): - result = connection.execute( - sql.text( - "SELECT nspname FROM pg_namespace " - "WHERE nspname NOT LIKE 'pg_%' " - "ORDER BY nspname" - ).columns(nspname=sqltypes.Unicode) + query = ( + select(pg_catalog.pg_namespace.c.nspname) + .where(pg_catalog.pg_namespace.c.nspname.not_like("pg_%")) + .order_by(pg_catalog.pg_namespace.c.nspname) + ) + return connection.scalars(query).all() + + def _get_relnames_for_relkinds(self, connection, schema, relkinds, scope): + query = select(pg_catalog.pg_class.c.relname).where( + self._pg_class_relkind_condition(relkinds) ) - return [name for name, in result] + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + return connection.scalars(query).all() @reflection.cache def get_table_names(self, connection, schema=None, **kw): - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind in ('r', 'p')" - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.DEFAULT, + ) + + @reflection.cache + def get_temp_table_names(self, connection, **kw): + return self._get_relnames_for_relkinds( + connection, + schema=None, + relkinds=pg_catalog.RELKINDS_TABLE_NO_FOREIGN, + scope=ObjectScope.TEMPORARY, ) - return [name for name, in result] @reflection.cache def _get_foreign_table_names(self, connection, schema=None, **kw): - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind = 'f'" - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("f",), scope=ObjectScope.ANY ) - return [name for name, in result] @reflection.cache - def get_view_names( - self, connection, schema=None, include=("plain", "materialized"), **kw - ): + def get_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.DEFAULT, + ) - include_kind = {"plain": "v", "materialized": "m"} - try: - kinds = [include_kind[i] for i in util.to_list(include)] - except KeyError: - raise ValueError( - "include %r unknown, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" % (include,) - ) - if not kinds: - raise ValueError( - "empty include, needs to be a sequence containing " - "one or both of 'plain' and 'materialized'" - ) + @reflection.cache + def get_materialized_view_names(self, connection, 
schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + pg_catalog.RELKINDS_MAT_VIEW, + scope=ObjectScope.DEFAULT, + ) - result = connection.execute( - sql.text( - "SELECT c.relname FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relkind IN (%s)" - % (", ".join("'%s'" % elem for elem in kinds)) - ).columns(relname=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name - ), + @reflection.cache + def get_temp_view_names(self, connection, schema=None, **kw): + return self._get_relnames_for_relkinds( + connection, + schema, + # NOTE: do not include temp materialzied views (that do not + # seem to be a thing at least up to version 14) + pg_catalog.RELKINDS_VIEW, + scope=ObjectScope.TEMPORARY, ) - return [name for name, in result] @reflection.cache def get_sequence_names(self, connection, schema=None, **kw): - if not schema: - schema = self.default_schema_name - cursor = connection.execute( - sql.text( - "SELECT relname FROM pg_class c join pg_namespace n on " - "n.oid=c.relnamespace where relkind='S' and " - "n.nspname=:schema" - ).bindparams( - sql.bindparam( - "schema", - str(schema), - type_=sqltypes.Unicode, - ), - ) + return self._get_relnames_for_relkinds( + connection, schema, relkinds=("S",), scope=ObjectScope.ANY ) - return [row[0] for row in cursor] @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): - view_def = connection.scalar( - sql.text( - "SELECT pg_get_viewdef(c.oid) view_def FROM pg_class c " - "JOIN pg_namespace n ON n.oid = c.relnamespace " - "WHERE n.nspname = :schema AND c.relname = :view_name " - "AND c.relkind IN ('v', 'm')" - ).columns(view_def=sqltypes.Unicode), - dict( - schema=schema - if schema is not None - else self.default_schema_name, - view_name=view_name, - ), + query = ( + select(pg_catalog.pg_get_viewdef(pg_catalog.pg_class.c.oid)) + .select_from(pg_catalog.pg_class) + .where( + pg_catalog.pg_class.c.relname == view_name, + self._pg_class_relkind_condition( + pg_catalog.RELKINDS_VIEW + pg_catalog.RELKINDS_MAT_VIEW + ), + ) ) - return view_def + query = self._pg_class_filter_scope_schema( + query, schema, scope=ObjectScope.ANY + ) + res = connection.scalar(query) + if res is None: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) + else: + return res + + def _value_or_raise(self, data, table, schema): + try: + return dict(data)[(schema, table)] + except KeyError: + raise exc.NoSuchTableError( + f"{schema}.{table}" if schema else table + ) from None + + def _prepare_filter_names(self, filter_names): + if filter_names: + return True, {"filter_names": filter_names} + else: + return False, {} + + def _kind_to_relkinds(self, kind: ObjectKind) -> tuple[str, ...]: + if kind is ObjectKind.ANY: + return pg_catalog.RELKINDS_ALL_TABLE_LIKE + relkinds = () + if ObjectKind.TABLE in kind: + relkinds += pg_catalog.RELKINDS_TABLE + if ObjectKind.VIEW in kind: + relkinds += pg_catalog.RELKINDS_VIEW + if ObjectKind.MATERIALIZED_VIEW in kind: + relkinds += pg_catalog.RELKINDS_MAT_VIEW + return relkinds @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): - - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_columns( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, 
table_name, schema) + @lru_cache() + def _columns_query(self, schema, has_filter_names, scope, kind): + # NOTE: the query with the default and identity options scalar + # subquery is faster than trying to use outer joins for them generated = ( - "a.attgenerated as generated" + pg_catalog.pg_attribute.c.attgenerated.label("generated") if self.server_version_info >= (12,) - else "NULL as generated" + else sql.null().label("generated") ) if self.server_version_info >= (10,): - # a.attidentity != '' is required or it will reflect also - # serial columns as identity. - identity = """\ - (SELECT json_build_object( - 'always', a.attidentity = 'a', - 'start', s.seqstart, - 'increment', s.seqincrement, - 'minvalue', s.seqmin, - 'maxvalue', s.seqmax, - 'cache', s.seqcache, - 'cycle', s.seqcycle) - FROM pg_catalog.pg_sequence s - JOIN pg_catalog.pg_class c on s.seqrelid = c."oid" - WHERE c.relkind = 'S' - AND a.attidentity != '' - AND s.seqrelid = pg_catalog.pg_get_serial_sequence( - a.attrelid::regclass::text, a.attname - )::regclass::oid - ) as identity_options\ - """ + # join lateral performs worse (~2x slower) than a scalar_subquery + identity = ( + select( + sql.func.json_build_object( + "always", + pg_catalog.pg_attribute.c.attidentity == "a", + "start", + pg_catalog.pg_sequence.c.seqstart, + "increment", + pg_catalog.pg_sequence.c.seqincrement, + "minvalue", + pg_catalog.pg_sequence.c.seqmin, + "maxvalue", + pg_catalog.pg_sequence.c.seqmax, + "cache", + pg_catalog.pg_sequence.c.seqcache, + "cycle", + pg_catalog.pg_sequence.c.seqcycle, + ) + ) + .select_from(pg_catalog.pg_sequence) + .where( + # attidentity != '' is required or it will reflect also + # serial columns as identity. + pg_catalog.pg_attribute.c.attidentity != "", + pg_catalog.pg_sequence.c.seqrelid + == sql.cast( + sql.cast( + pg_catalog.pg_get_serial_sequence( + sql.cast( + sql.cast( + pg_catalog.pg_attribute.c.attrelid, + REGCLASS, + ), + TEXT, + ), + pg_catalog.pg_attribute.c.attname, + ), + REGCLASS, + ), + OID, + ), + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("identity_options") + ) else: - identity = "NULL as identity_options" - - SQL_COLS = """ - SELECT a.attname, - pg_catalog.format_type(a.atttypid, a.atttypmod), - ( - SELECT pg_catalog.pg_get_expr(d.adbin, d.adrelid) - FROM pg_catalog.pg_attrdef d - WHERE d.adrelid = a.attrelid AND d.adnum = a.attnum - AND a.atthasdef - ) AS DEFAULT, - a.attnotnull, - a.attrelid as table_oid, - pgd.description as comment, - %s, - %s - FROM pg_catalog.pg_attribute a - LEFT JOIN pg_catalog.pg_description pgd ON ( - pgd.objoid = a.attrelid AND pgd.objsubid = a.attnum) - WHERE a.attrelid = :table_oid - AND a.attnum > 0 AND NOT a.attisdropped - ORDER BY a.attnum - """ % ( - generated, - identity, + identity = sql.null().label("identity_options") + + # join lateral performs the same as scalar_subquery here + default = ( + select( + pg_catalog.pg_get_expr( + pg_catalog.pg_attrdef.c.adbin, + pg_catalog.pg_attrdef.c.adrelid, + ) + ) + .select_from(pg_catalog.pg_attrdef) + .where( + pg_catalog.pg_attrdef.c.adrelid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attrdef.c.adnum + == pg_catalog.pg_attribute.c.attnum, + pg_catalog.pg_attribute.c.atthasdef, + ) + .correlate(pg_catalog.pg_attribute) + .scalar_subquery() + .label("default") ) - s = ( - sql.text(SQL_COLS) - .bindparams(sql.bindparam("table_oid", type_=sqltypes.Integer)) - .columns(attname=sqltypes.Unicode, default=sqltypes.Unicode) + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + 
pg_catalog.pg_attribute.c.attname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_attribute.c.atttypid, + pg_catalog.pg_attribute.c.atttypmod, + ).label("format_type"), + default, + pg_catalog.pg_attribute.c.attnotnull.label("not_null"), + pg_catalog.pg_class.c.relname.label("table_name"), + pg_catalog.pg_description.c.description.label("comment"), + generated, + identity, + ) + .select_from(pg_catalog.pg_class) + # NOTE: postgresql support table with no user column, meaning + # there is no row with pg_attribute.attnum > 0. use a left outer + # join to avoid filtering these tables. + .outerjoin( + pg_catalog.pg_attribute, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_attribute.c.attnum > 0, + ~pg_catalog.pg_attribute.c.attisdropped, + ), + ) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_description.c.objoid + == pg_catalog.pg_attribute.c.attrelid, + pg_catalog.pg_description.c.objsubid + == pg_catalog.pg_attribute.c.attnum, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + .order_by( + pg_catalog.pg_class.c.relname, pg_catalog.pg_attribute.c.attnum + ) ) - c = connection.execute(s, dict(table_oid=table_oid)) - rows = c.fetchall() + query = self._pg_class_filter_scope_schema(query, schema, scope=scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + def get_multi_columns( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._columns_query(schema, has_filter_names, scope, kind) + rows = connection.execute(query, params).mappings() # dictionary with (name, ) if default search path or (schema, name) # as keys - domains = self._load_domains(connection) + domains = self._load_domains( + connection, info_cache=kw.get("info_cache") + ) # dictionary with (name, ) if default search path or (schema, name) # as keys @@ -3732,257 +3329,340 @@ class PGDialect(default.DefaultDialect): ((rec["name"],), rec) if rec["visible"] else ((rec["schema"], rec["name"]), rec) - for rec in self._load_enums(connection, schema="*") + for rec in self._load_enums( + connection, schema="*", info_cache=kw.get("info_cache") + ) ) - # format columns - columns = [] - - for ( - name, - format_type, - default_, - notnull, - table_oid, - comment, - generated, - identity, - ) in rows: - column_info = self._get_column_info( - name, - format_type, - default_, - notnull, - domains, - enums, - schema, - comment, - generated, - identity, - ) - columns.append(column_info) - return columns + columns = self._get_columns_info(rows, domains, enums, schema) + + return columns.items() + + def _get_columns_info(self, rows, domains, enums, schema): + array_type_pattern = re.compile(r"\[\]$") + attype_pattern = re.compile(r"\(.*\)") + charlen_pattern = re.compile(r"\(([\d,]+)\)") + args_pattern = re.compile(r"\((.*)\)") + args_split_pattern = re.compile(r"\s*,\s*") - def _get_column_info( - self, - name, - format_type, - default, - notnull, - domains, - enums, - schema, - comment, - generated, - identity, - ): def _handle_array_type(attype): return ( # strip '[]' from integer[], etc. - re.sub(r"\[\]$", "", attype), + array_type_pattern.sub("", attype), attype.endswith("[]"), ) - # strip (*) from character varying(5), timestamp(5) - # with time zone, geometry(POLYGON), etc. 
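
The reworked ``get_columns()`` above keeps its per-column dictionary result
format while sourcing the data from the batched ``get_multi_columns()`` query;
a usage sketch with placeholder names::

    from sqlalchemy import create_engine, inspect

    engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = inspect(engine)

    for col in insp.get_columns("some_table", schema="public"):
        # each entry is a dict with keys such as name, type, nullable, default
        print(col["name"], col["type"], col["nullable"], col.get("default"))
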
- attype = re.sub(r"\(.*\)", "", format_type) + columns = defaultdict(list) + for row_dict in rows: + # ensure that each table has an entry, even if it has no columns + if row_dict["name"] is None: + columns[ + (schema, row_dict["table_name"]) + ] = ReflectionDefaults.columns() + continue + table_cols = columns[(schema, row_dict["table_name"])] - # strip '[]' from integer[], etc. and check if an array - attype, is_array = _handle_array_type(attype) + format_type = row_dict["format_type"] + default = row_dict["default"] + name = row_dict["name"] + generated = row_dict["generated"] + identity = row_dict["identity_options"] - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + # strip (*) from character varying(5), timestamp(5) + # with time zone, geometry(POLYGON), etc. + attype = attype_pattern.sub("", format_type) - nullable = not notnull + # strip '[]' from integer[], etc. and check if an array + attype, is_array = _handle_array_type(attype) - charlen = re.search(r"\(([\d,]+)\)", format_type) - if charlen: - charlen = charlen.group(1) - args = re.search(r"\((.*)\)", format_type) - if args and args.group(1): - args = tuple(re.split(r"\s*,\s*", args.group(1))) - else: - args = () - kwargs = {} + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple(util.quoted_token_parser(attype)) + + nullable = not row_dict["not_null"] - if attype == "numeric": + charlen = charlen_pattern.search(format_type) if charlen: - prec, scale = charlen.split(",") - args = (int(prec), int(scale)) + charlen = charlen.group(1) + args = args_pattern.search(format_type) + if args and args.group(1): + args = tuple(args_split_pattern.split(args.group(1))) else: args = () - elif attype == "double precision": - args = (53,) - elif attype == "integer": - args = () - elif attype in ("timestamp with time zone", "time with time zone"): - kwargs["timezone"] = True - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype in ( - "timestamp without time zone", - "time without time zone", - "time", - ): - kwargs["timezone"] = False - if charlen: - kwargs["precision"] = int(charlen) - args = () - elif attype == "bit varying": - kwargs["varying"] = True - if charlen: + kwargs = {} + + if attype == "numeric": + if charlen: + prec, scale = charlen.split(",") + args = (int(prec), int(scale)) + else: + args = () + elif attype == "double precision": + args = (53,) + elif attype == "integer": + args = () + elif attype in ("timestamp with time zone", "time with time zone"): + kwargs["timezone"] = True + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype in ( + "timestamp without time zone", + "time without time zone", + "time", + ): + kwargs["timezone"] = False + if charlen: + kwargs["precision"] = int(charlen) + args = () + elif attype == "bit varying": + kwargs["varying"] = True + if charlen: + args = (int(charlen),) + else: + args = () + elif attype.startswith("interval"): + field_match = re.match(r"interval (.+)", attype, re.I) + if charlen: + kwargs["precision"] = int(charlen) + if field_match: + kwargs["fields"] = field_match.group(1) + attype = "interval" + args = () + elif charlen: args = (int(charlen),) + + while True: + # looping here to suit nested domains + if attype in self.ischema_names: + coltype = self.ischema_names[attype] + break + elif enum_or_domain_key in enums: + enum = enums[enum_or_domain_key] + coltype = ENUM + kwargs["name"] = enum["name"] + if not enum["visible"]: + 
kwargs["schema"] = enum["schema"] + args = tuple(enum["labels"]) + break + elif enum_or_domain_key in domains: + domain = domains[enum_or_domain_key] + attype = domain["attype"] + attype, is_array = _handle_array_type(attype) + # strip quotes from case sensitive enum or domain names + enum_or_domain_key = tuple( + util.quoted_token_parser(attype) + ) + # A table can't override a not null on the domain, + # but can override nullable + nullable = nullable and domain["nullable"] + if domain["default"] and not default: + # It can, however, override the default + # value, but can't set it to null. + default = domain["default"] + continue + else: + coltype = None + break + + if coltype: + coltype = coltype(*args, **kwargs) + if is_array: + coltype = self.ischema_names["_array"](coltype) else: - args = () - elif attype.startswith("interval"): - field_match = re.match(r"interval (.+)", attype, re.I) - if charlen: - kwargs["precision"] = int(charlen) - if field_match: - kwargs["fields"] = field_match.group(1) - attype = "interval" - args = () - elif charlen: - args = (int(charlen),) - - while True: - # looping here to suit nested domains - if attype in self.ischema_names: - coltype = self.ischema_names[attype] - break - elif enum_or_domain_key in enums: - enum = enums[enum_or_domain_key] - coltype = ENUM - kwargs["name"] = enum["name"] - if not enum["visible"]: - kwargs["schema"] = enum["schema"] - args = tuple(enum["labels"]) - break - elif enum_or_domain_key in domains: - domain = domains[enum_or_domain_key] - attype = domain["attype"] - attype, is_array = _handle_array_type(attype) - # strip quotes from case sensitive enum or domain names - enum_or_domain_key = tuple(util.quoted_token_parser(attype)) - # A table can't override a not null on the domain, - # but can override nullable - nullable = nullable and domain["nullable"] - if domain["default"] and not default: - # It can, however, override the default - # value, but can't set it to null. - default = domain["default"] - continue + util.warn( + "Did not recognize type '%s' of column '%s'" + % (attype, name) + ) + coltype = sqltypes.NULLTYPE + + # If a zero byte or blank string depending on driver (is also + # absent for older PG versions), then not a generated column. + # Otherwise, s = stored. (Other values might be added in the + # future.) + if generated not in (None, "", b"\x00"): + computed = dict( + sqltext=default, persisted=generated in ("s", b"s") + ) + default = None else: - coltype = None - break + computed = None - if coltype: - coltype = coltype(*args, **kwargs) - if is_array: - coltype = self.ischema_names["_array"](coltype) - else: - util.warn( - "Did not recognize type '%s' of column '%s'" % (attype, name) + # adjust the default value + autoincrement = False + if default is not None: + match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) + if match is not None: + if issubclass(coltype._type_affinity, sqltypes.Integer): + autoincrement = True + # the default is related to a Sequence + if "." not in match.group(2) and schema is not None: + # unconditionally quote the schema name. this could + # later be enhanced to obey quoting rules / + # "quote schema" + default = ( + match.group(1) + + ('"%s"' % schema) + + "." 
+ + match.group(2) + + match.group(3) + ) + + column_info = { + "name": name, + "type": coltype, + "nullable": nullable, + "default": default, + "autoincrement": autoincrement or identity is not None, + "comment": row_dict["comment"], + } + if computed is not None: + column_info["computed"] = computed + if identity is not None: + column_info["identity"] = identity + + table_cols.append(column_info) + + return columns + + @lru_cache() + def _table_oids_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + oid_q = select( + pg_catalog.pg_class.c.oid, pg_catalog.pg_class.c.relname + ).where(self._pg_class_relkind_condition(relkinds)) + oid_q = self._pg_class_filter_scope_schema(oid_q, schema, scope=scope) + + if has_filter_names: + oid_q = oid_q.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) ) - coltype = sqltypes.NULLTYPE - - # If a zero byte or blank string depending on driver (is also absent - # for older PG versions), then not a generated column. Otherwise, s = - # stored. (Other values might be added in the future.) - if generated not in (None, "", b"\x00"): - computed = dict( - sqltext=default, persisted=generated in ("s", b"s") + return oid_q + + @reflection.flexi_cache( + ("schema", InternalTraversal.dp_string), + ("filter_names", InternalTraversal.dp_string_list), + ("kind", InternalTraversal.dp_plain_obj), + ("scope", InternalTraversal.dp_plain_obj), + ) + def _get_table_oids( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + oid_q = self._table_oids_query(schema, has_filter_names, scope, kind) + result = connection.execute(oid_q, params) + return result.all() + + @util.memoized_property + def _constraint_query(self): + con_sq = ( + select( + pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.conname, + sql.func.unnest(pg_catalog.pg_constraint.c.conkey).label( + "attnum" + ), + sql.func.generate_subscripts( + pg_catalog.pg_constraint.c.conkey, 1 + ).label("ord"), ) - default = None - else: - computed = None - - # adjust the default value - autoincrement = False - if default is not None: - match = re.search(r"""(nextval\(')([^']+)('.*$)""", default) - if match is not None: - if issubclass(coltype._type_affinity, sqltypes.Integer): - autoincrement = True - # the default is related to a Sequence - sch = schema - if "." not in match.group(2) and sch is not None: - # unconditionally quote the schema name. this could - # later be enhanced to obey quoting rules / - # "quote schema" - default = ( - match.group(1) - + ('"%s"' % sch) - + "." 
- + match.group(2) - + match.group(3) - ) + .where( + pg_catalog.pg_constraint.c.contype == bindparam("contype"), + pg_catalog.pg_constraint.c.conrelid.in_(bindparam("oids")), + ) + .subquery("con") + ) - column_info = dict( - name=name, - type=coltype, - nullable=nullable, - default=default, - autoincrement=autoincrement or identity is not None, - comment=comment, + attr_sq = ( + select( + con_sq.c.conrelid, + con_sq.c.conname, + pg_catalog.pg_attribute.c.attname, + ) + .select_from(pg_catalog.pg_attribute) + .join( + con_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == con_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == con_sq.c.conrelid, + ), + ) + .order_by(con_sq.c.conname, con_sq.c.ord) + .subquery("attr") ) - if computed is not None: - column_info["computed"] = computed - if identity is not None: - column_info["identity"] = identity - return column_info - @reflection.cache - def get_pk_constraint(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + return ( + select( + attr_sq.c.conrelid, + sql.func.array_agg(attr_sq.c.attname).label("cols"), + attr_sq.c.conname, + ) + .group_by(attr_sq.c.conrelid, attr_sq.c.conname) + .order_by(attr_sq.c.conrelid, attr_sq.c.conname) ) - if self.server_version_info < (8, 4): - PK_SQL = """ - SELECT a.attname - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_attribute a - on t.oid=a.attrelid AND %s - WHERE - t.oid = :table_oid and ix.indisprimary = 't' - ORDER BY a.attnum - """ % self._pg_index_any( - "a.attnum", "ix.indkey" + def _reflect_constraint( + self, connection, contype, schema, filter_names, scope, kind, **kw + ): + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._constraint_query, + {"oids": [r[0] for r in batch], "contype": contype}, ) - else: - # unnest() and generate_subscripts() both introduced in - # version 8.4 - PK_SQL = """ - SELECT a.attname - FROM pg_attribute a JOIN ( - SELECT unnest(ix.indkey) attnum, - generate_subscripts(ix.indkey, 1) ord - FROM pg_index ix - WHERE ix.indrelid = :table_oid AND ix.indisprimary - ) k ON a.attnum=k.attnum - WHERE a.attrelid = :table_oid - ORDER BY k.ord - """ - t = sql.text(PK_SQL).columns(attname=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) - cols = [r[0] for r in c.fetchall()] - - PK_CONS_SQL = """ - SELECT conname - FROM pg_catalog.pg_constraint r - WHERE r.conrelid = :table_oid AND r.contype = 'p' - ORDER BY 1 - """ - t = sql.text(PK_CONS_SQL).columns(conname=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) - name = c.scalar() + result_by_oid = defaultdict(list) + for oid, cols, constraint_name in result: + result_by_oid[oid].append((cols, constraint_name)) + + for oid, tablename in batch: + for_oid = result_by_oid.get(oid, ()) + if for_oid: + for cols, constraint in for_oid: + yield tablename, cols, constraint + else: + yield tablename, None, None + + @reflection.cache + def get_pk_constraint(self, connection, table_name, schema=None, **kw): + data = self.get_multi_pk_constraint( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, + ) + return self._value_or_raise(data, table_name, schema) + + def get_multi_pk_constraint( + self, connection, schema, filter_names, scope, 
kind, **kw + ): + result = self._reflect_constraint( + connection, "p", schema, filter_names, scope, kind, **kw + ) - return {"constrained_columns": cols, "name": name} + # only a single pk can be present for each table. Return an entry + # even if a table has no primary key + default = ReflectionDefaults.pk_constraint + return ( + ( + (schema, table_name), + { + "constrained_columns": [] if cols is None else cols, + "name": pk_name, + } + if pk_name is not None + else default(), + ) + for (table_name, cols, pk_name) in result + ) @reflection.cache def get_foreign_keys( @@ -3993,27 +3673,71 @@ class PGDialect(default.DefaultDialect): postgresql_ignore_search_path=False, **kw, ): - preparer = self.identifier_preparer - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_foreign_keys( + connection, + schema=schema, + filter_names=[table_name], + postgresql_ignore_search_path=postgresql_ignore_search_path, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - FK_SQL = """ - SELECT r.conname, - pg_catalog.pg_get_constraintdef(r.oid, true) as condef, - n.nspname as conschema - FROM pg_catalog.pg_constraint r, - pg_namespace n, - pg_class c - - WHERE r.conrelid = :table AND - r.contype = 'f' AND - c.oid = confrelid AND - n.oid = c.relnamespace - ORDER BY 1 - """ - # https://www.postgresql.org/docs/9.0/static/sql-createtable.html - FK_REGEX = re.compile( + @lru_cache() + def _foreing_key_query(self, schema, has_filter_names, scope, kind): + pg_class_ref = pg_catalog.pg_class.alias("cls_ref") + pg_namespace_ref = pg_catalog.pg_namespace.alias("nsp_ref") + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid, True + ), + ), + else_=None, + ), + pg_namespace_ref.c.nspname, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "f", + ), + ) + .outerjoin( + pg_class_ref, + pg_class_ref.c.oid == pg_catalog.pg_constraint.c.confrelid, + ) + .outerjoin( + pg_namespace_ref, + pg_class_ref.c.relnamespace == pg_namespace_ref.c.oid, + ) + .order_by( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query + + @util.memoized_property + def _fk_regex_pattern(self): + # https://www.postgresql.org/docs/14.0/static/sql-createtable.html + return re.compile( r"FOREIGN KEY \((.*?)\) REFERENCES (?:(.*?)\.)?(.*?)\((.*?)\)" r"[\s]?(MATCH (FULL|PARTIAL|SIMPLE)+)?" r"[\s]?(ON UPDATE " @@ -4024,12 +3748,33 @@ class PGDialect(default.DefaultDialect): r"[\s]?(INITIALLY (DEFERRED|IMMEDIATE)+)?" 
) - t = sql.text(FK_SQL).columns( - conname=sqltypes.Unicode, condef=sqltypes.Unicode - ) - c = connection.execute(t, dict(table=table_oid)) - fkeys = [] - for conname, condef, conschema in c.fetchall(): + def get_multi_foreign_keys( + self, + connection, + schema, + filter_names, + scope, + kind, + postgresql_ignore_search_path=False, + **kw, + ): + preparer = self.identifier_preparer + + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._foreing_key_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + FK_REGEX = self._fk_regex_pattern + + fkeys = defaultdict(list) + default = ReflectionDefaults.foreign_keys + for table_name, conname, condef, conschema in result: + # ensure that each table has an entry, even if it has + # no foreign keys + if conname is None: + fkeys[(schema, table_name)] = default() + continue + table_fks = fkeys[(schema, table_name)] m = re.search(FK_REGEX, condef).groups() ( @@ -4096,317 +3841,406 @@ class PGDialect(default.DefaultDialect): "referred_columns": referred_columns, "options": options, } - fkeys.append(fkey_d) - return fkeys - - def _pg_index_any(self, col, compare_to): - if self.server_version_info < (8, 1): - # https://www.postgresql.org/message-id/10279.1124395722@sss.pgh.pa.us - # "In CVS tip you could replace this with "attnum = ANY (indkey)". - # Unfortunately, most array support doesn't work on int2vector in - # pre-8.1 releases, so I think you're kinda stuck with the above - # for now. - # regards, tom lane" - return "(%s)" % " OR ".join( - "%s[%d] = %s" % (compare_to, ind, col) for ind in range(0, 10) - ) - else: - return "%s = ANY(%s)" % (col, compare_to) + table_fks.append(fkey_d) + return fkeys.items() @reflection.cache - def get_indexes(self, connection, table_name, schema, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + def get_indexes(self, connection, table_name, schema=None, **kw): + data = self.get_multi_indexes( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - # cast indkey as varchar since it's an int2vector, - # returned as a list by some drivers such as pypostgresql - - if self.server_version_info < (8, 5): - IDX_SQL = """ - SELECT - i.relname as relname, - ix.indisunique, ix.indexprs, ix.indpred, - a.attname, a.attnum, NULL, ix.indkey%s, - %s, %s, am.amname, - NULL as indnkeyatts - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid = ix.indexrelid - left outer join - pg_attribute a - on t.oid = a.attrelid and %s - left outer join - pg_am am - on i.relam = am.oid - WHERE - t.relkind IN ('r', 'v', 'f', 'm') - and t.oid = :table_oid - and ix.indisprimary = 'f' - ORDER BY - t.relname, - i.relname - """ % ( - # version 8.3 here was based on observing the - # cast does not work in PG 8.2.4, does work in 8.3.0. - # nothing in PG changelogs regarding this. 
- "::varchar" if self.server_version_info >= (8, 3) else "", - "ix.indoption::varchar" - if self.server_version_info >= (8, 3) - else "NULL", - "i.reloptions" - if self.server_version_info >= (8, 2) - else "NULL", - self._pg_index_any("a.attnum", "ix.indkey"), + @util.memoized_property + def _index_query(self): + pg_class_index = pg_catalog.pg_class.alias("cls_idx") + # NOTE: repeating oids clause improve query performance + + # subquery to get the columns + idx_sq = ( + select( + pg_catalog.pg_index.c.indexrelid, + pg_catalog.pg_index.c.indrelid, + sql.func.unnest(pg_catalog.pg_index.c.indkey).label("attnum"), + sql.func.generate_subscripts( + pg_catalog.pg_index.c.indkey, 1 + ).label("ord"), ) - else: - IDX_SQL = """ - SELECT - i.relname as relname, - ix.indisunique, ix.indexprs, - a.attname, a.attnum, c.conrelid, ix.indkey::varchar, - ix.indoption::varchar, i.reloptions, am.amname, - pg_get_expr(ix.indpred, ix.indrelid), - %s as indnkeyatts - FROM - pg_class t - join pg_index ix on t.oid = ix.indrelid - join pg_class i on i.oid = ix.indexrelid - left outer join - pg_attribute a - on t.oid = a.attrelid and a.attnum = ANY(ix.indkey) - left outer join - pg_constraint c - on (ix.indrelid = c.conrelid and - ix.indexrelid = c.conindid and - c.contype in ('p', 'u', 'x')) - left outer join - pg_am am - on i.relam = am.oid - WHERE - t.relkind IN ('r', 'v', 'f', 'm', 'p') - and t.oid = :table_oid - and ix.indisprimary = 'f' - ORDER BY - t.relname, - i.relname - """ % ( - "ix.indnkeyatts" - if self.server_version_info >= (11, 0) - else "NULL", + .where( + ~pg_catalog.pg_index.c.indisprimary, + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), ) + .subquery("idx") + ) - t = sql.text(IDX_SQL).columns( - relname=sqltypes.Unicode, attname=sqltypes.Unicode + attr_sq = ( + select( + idx_sq.c.indexrelid, + idx_sq.c.indrelid, + pg_catalog.pg_attribute.c.attname, + ) + .select_from(pg_catalog.pg_attribute) + .join( + idx_sq, + sql.and_( + pg_catalog.pg_attribute.c.attnum == idx_sq.c.attnum, + pg_catalog.pg_attribute.c.attrelid == idx_sq.c.indrelid, + ), + ) + .where(idx_sq.c.indrelid.in_(bindparam("oids"))) + .order_by(idx_sq.c.indexrelid, idx_sq.c.ord) + .subquery("idx_attr") ) - c = connection.execute(t, dict(table_oid=table_oid)) - indexes = defaultdict(lambda: defaultdict(dict)) + cols_sq = ( + select( + attr_sq.c.indexrelid, + attr_sq.c.indrelid, + sql.func.array_agg(attr_sq.c.attname).label("cols"), + ) + .group_by(attr_sq.c.indexrelid, attr_sq.c.indrelid) + .subquery("idx_cols") + ) - sv_idx_name = None - for row in c.fetchall(): - ( - idx_name, - unique, - expr, - col, - col_num, - conrelid, - idx_key, - idx_option, - options, - amname, - filter_definition, - indnkeyatts, - ) = row + if self.server_version_info >= (11, 0): + indnkeyatts = pg_catalog.pg_index.c.indnkeyatts + else: + indnkeyatts = sql.null().label("indnkeyatts") - if expr: - if idx_name != sv_idx_name: - util.warn( - "Skipped unsupported reflection of " - "expression-based index %s" % idx_name - ) - sv_idx_name = idx_name - continue + query = ( + select( + pg_catalog.pg_index.c.indrelid, + pg_class_index.c.relname.label("relname_index"), + pg_catalog.pg_index.c.indisunique, + pg_catalog.pg_index.c.indexprs, + pg_catalog.pg_constraint.c.conrelid.is_not(None).label( + "has_constraint" + ), + pg_catalog.pg_index.c.indoption, + pg_class_index.c.reloptions, + pg_catalog.pg_am.c.amname, + pg_catalog.pg_get_expr( + pg_catalog.pg_index.c.indpred, + pg_catalog.pg_index.c.indrelid, + ).label("filter_definition"), + indnkeyatts, + 
cols_sq.c.cols.label("index_cols"), + ) + .select_from(pg_catalog.pg_index) + .where( + pg_catalog.pg_index.c.indrelid.in_(bindparam("oids")), + ~pg_catalog.pg_index.c.indisprimary, + ) + .join( + pg_class_index, + pg_catalog.pg_index.c.indexrelid == pg_class_index.c.oid, + ) + .join( + pg_catalog.pg_am, + pg_class_index.c.relam == pg_catalog.pg_am.c.oid, + ) + .outerjoin( + cols_sq, + pg_catalog.pg_index.c.indexrelid == cols_sq.c.indexrelid, + ) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_index.c.indrelid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_index.c.indexrelid + == pg_catalog.pg_constraint.c.conindid, + pg_catalog.pg_constraint.c.contype + == sql.any_(_array.array(("p", "u", "x"))), + ), + ) + .order_by(pg_catalog.pg_index.c.indrelid, pg_class_index.c.relname) + ) + return query - has_idx = idx_name in indexes - index = indexes[idx_name] - if col is not None: - index["cols"][col_num] = col - if not has_idx: - idx_keys = idx_key.split() - # "The number of key columns in the index, not counting any - # included columns, which are merely stored and do not - # participate in the index semantics" - if indnkeyatts and idx_keys[indnkeyatts:]: - # this is a "covering index" which has INCLUDE columns - # as well as regular index columns - inc_keys = idx_keys[indnkeyatts:] - idx_keys = idx_keys[:indnkeyatts] - else: - inc_keys = [] + def get_multi_indexes( + self, connection, schema, filter_names, scope, kind, **kw + ): - index["key"] = [int(k.strip()) for k in idx_keys] - index["inc"] = [int(k.strip()) for k in inc_keys] + table_oids = self._get_table_oids( + connection, schema, filter_names, scope, kind, **kw + ) - # (new in pg 8.3) - # "pg_index.indoption" is list of ints, one per column/expr. - # int acts as bitmask: 0x01=DESC, 0x02=NULLSFIRST - sorting = {} - for col_idx, col_flags in enumerate( - (idx_option or "").split() - ): - col_flags = int(col_flags.strip()) - col_sorting = () - # try to set flags only if they differ from PG defaults... 
- if col_flags & 0x01: - col_sorting += ("desc",) - if not (col_flags & 0x02): - col_sorting += ("nulls_last",) + indexes = defaultdict(list) + default = ReflectionDefaults.indexes + + batches = list(table_oids) + + while batches: + batch = batches[0:3000] + batches[0:3000] = [] + + result = connection.execute( + self._index_query, {"oids": [r[0] for r in batch]} + ).mappings() + + result_by_oid = defaultdict(list) + for row_dict in result: + result_by_oid[row_dict["indrelid"]].append(row_dict) + + for oid, table_name in batch: + if oid not in result_by_oid: + # ensure that each table has an entry, even if reflection + # is skipped because not supported + indexes[(schema, table_name)] = default() + continue + + for row in result_by_oid[oid]: + index_name = row["relname_index"] + + table_indexes = indexes[(schema, table_name)] + + if row["indexprs"]: + tn = ( + table_name + if schema is None + else f"{schema}.{table_name}" + ) + util.warn( + "Skipped unsupported reflection of " + f"expression-based index {index_name} of " + f"table {tn}" + ) + continue + + all_cols = row["index_cols"] + indnkeyatts = row["indnkeyatts"] + # "The number of key columns in the index, not counting any + # included columns, which are merely stored and do not + # participate in the index semantics" + if indnkeyatts and all_cols[indnkeyatts:]: + # this is a "covering index" which has INCLUDE columns + # as well as regular index columns + inc_cols = all_cols[indnkeyatts:] + idx_cols = all_cols[:indnkeyatts] else: - if col_flags & 0x02: - col_sorting += ("nulls_first",) - if col_sorting: - sorting[col_idx] = col_sorting - if sorting: - index["sorting"] = sorting - - index["unique"] = unique - if conrelid is not None: - index["duplicates_constraint"] = idx_name - if options: - index["options"] = dict( - [option.split("=") for option in options] - ) - - # it *might* be nice to include that this is 'btree' in the - # reflection info. But we don't want an Index object - # to have a ``postgresql_using`` in it that is just the - # default, so for the moment leaving this out. - if amname and amname != "btree": - index["amname"] = amname - - if filter_definition: - index["postgresql_where"] = filter_definition + idx_cols = all_cols + inc_cols = [] + + index = { + "name": index_name, + "unique": row["indisunique"], + "column_names": idx_cols, + } + + sorting = {} + for col_index, col_flags in enumerate(row["indoption"]): + col_sorting = () + # try to set flags only if they differ from PG + # defaults... + if col_flags & 0x01: + col_sorting += ("desc",) + if not (col_flags & 0x02): + col_sorting += ("nulls_last",) + else: + if col_flags & 0x02: + col_sorting += ("nulls_first",) + if col_sorting: + sorting[idx_cols[col_index]] = col_sorting + if sorting: + index["column_sorting"] = sorting + if row["has_constraint"]: + index["duplicates_constraint"] = index_name + + dialect_options = {} + if row["reloptions"]: + dialect_options["postgresql_with"] = dict( + [option.split("=") for option in row["reloptions"]] + ) + # it *might* be nice to include that this is 'btree' in the + # reflection info. But we don't want an Index object + # to have a ``postgresql_using`` in it that is just the + # default, so for the moment leaving this out. 
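+                    # Descriptive note: the following lines record a
+                    # non-default index access method and any partial-index
+                    # predicate as PostgreSQL-specific dialect options
+                    # ("postgresql_using" / "postgresql_where").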
+ amname = row["amname"] + if amname != "btree": + dialect_options["postgresql_using"] = row["amname"] + if row["filter_definition"]: + dialect_options["postgresql_where"] = row[ + "filter_definition" + ] + if self.server_version_info >= (11, 0): + # NOTE: this is legacy, this is part of + # dialect_options now as of #7382 + index["include_columns"] = inc_cols + dialect_options["postgresql_include"] = inc_cols + if dialect_options: + index["dialect_options"] = dialect_options - result = [] - for name, idx in indexes.items(): - entry = { - "name": name, - "unique": idx["unique"], - "column_names": [idx["cols"][i] for i in idx["key"]], - } - if self.server_version_info >= (11, 0): - # NOTE: this is legacy, this is part of dialect_options now - # as of #7382 - entry["include_columns"] = [idx["cols"][i] for i in idx["inc"]] - if "duplicates_constraint" in idx: - entry["duplicates_constraint"] = idx["duplicates_constraint"] - if "sorting" in idx: - entry["column_sorting"] = dict( - (idx["cols"][idx["key"][i]], value) - for i, value in idx["sorting"].items() - ) - if "include_columns" in entry: - entry.setdefault("dialect_options", {})[ - "postgresql_include" - ] = entry["include_columns"] - if "options" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_with" - ] = idx["options"] - if "amname" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_using" - ] = idx["amname"] - if "postgresql_where" in idx: - entry.setdefault("dialect_options", {})[ - "postgresql_where" - ] = idx["postgresql_where"] - result.append(entry) - return result + table_indexes.append(index) + return indexes.items() @reflection.cache def get_unique_constraints( self, connection, table_name, schema=None, **kw ): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_unique_constraints( + connection, + schema=schema, + filter_names=[table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - UNIQUE_SQL = """ - SELECT - cons.conname as name, - cons.conkey as key, - a.attnum as col_num, - a.attname as col_name - FROM - pg_catalog.pg_constraint cons - join pg_attribute a - on cons.conrelid = a.attrelid AND - a.attnum = ANY(cons.conkey) - WHERE - cons.conrelid = :table_oid AND - cons.contype = 'u' - """ - - t = sql.text(UNIQUE_SQL).columns(col_name=sqltypes.Unicode) - c = connection.execute(t, dict(table_oid=table_oid)) + def get_multi_unique_constraints( + self, + connection, + schema, + filter_names, + scope, + kind, + **kw, + ): + result = self._reflect_constraint( + connection, "u", schema, filter_names, scope, kind, **kw + ) - uniques = defaultdict(lambda: defaultdict(dict)) - for row in c.fetchall(): - uc = uniques[row.name] - uc["key"] = row.key - uc["cols"][row.col_num] = row.col_name + # each table can have multiple unique constraints + uniques = defaultdict(list) + default = ReflectionDefaults.unique_constraints + for (table_name, cols, con_name) in result: + # ensure a list is created for each table. 
leave it empty if + # the table has no unique cosntraint + if con_name is None: + uniques[(schema, table_name)] = default() + continue - return [ - {"name": name, "column_names": [uc["cols"][i] for i in uc["key"]]} - for name, uc in uniques.items() - ] + uniques[(schema, table_name)].append( + { + "column_names": cols, + "name": con_name, + } + ) + return uniques.items() @reflection.cache def get_table_comment(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_table_comment( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - COMMENT_SQL = """ - SELECT - pgd.description as table_comment - FROM - pg_catalog.pg_description pgd - WHERE - pgd.objsubid = 0 AND - pgd.objoid = :table_oid - """ + @lru_cache() + def _comment_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_description.c.description, + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_description, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_description.c.objoid, + pg_catalog.pg_description.c.objsubid == 0, + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query - c = connection.execute( - sql.text(COMMENT_SQL), dict(table_oid=table_oid) + def get_multi_table_comment( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._comment_query(schema, has_filter_names, scope, kind) + result = connection.execute(query, params) + + default = ReflectionDefaults.table_comment + return ( + ( + (schema, table), + {"text": comment} if comment is not None else default(), + ) + for table, comment in result ) - return {"text": c.scalar()} @reflection.cache def get_check_constraints(self, connection, table_name, schema=None, **kw): - table_oid = self.get_table_oid( - connection, table_name, schema, info_cache=kw.get("info_cache") + data = self.get_multi_check_constraints( + connection, + schema, + [table_name], + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + **kw, ) + return self._value_or_raise(data, table_name, schema) - CHECK_SQL = """ - SELECT - cons.conname as name, - pg_get_constraintdef(cons.oid) as src - FROM - pg_catalog.pg_constraint cons - WHERE - cons.conrelid = :table_oid AND - cons.contype = 'c' - """ - - c = connection.execute(sql.text(CHECK_SQL), dict(table_oid=table_oid)) + @lru_cache() + def _check_constraint_query(self, schema, has_filter_names, scope, kind): + relkinds = self._kind_to_relkinds(kind) + query = ( + select( + pg_catalog.pg_class.c.relname, + pg_catalog.pg_constraint.c.conname, + sql.case( + ( + pg_catalog.pg_constraint.c.oid.is_not(None), + pg_catalog.pg_get_constraintdef( + pg_catalog.pg_constraint.c.oid + ), + ), + else_=None, + ), + ) + .select_from(pg_catalog.pg_class) + .outerjoin( + pg_catalog.pg_constraint, + sql.and_( + pg_catalog.pg_class.c.oid + == pg_catalog.pg_constraint.c.conrelid, + pg_catalog.pg_constraint.c.contype == "c", + ), + ) + .where(self._pg_class_relkind_condition(relkinds)) + ) + query = 
self._pg_class_filter_scope_schema(query, schema, scope) + if has_filter_names: + query = query.where( + pg_catalog.pg_class.c.relname.in_(bindparam("filter_names")) + ) + return query - ret = [] - for name, src in c: + def get_multi_check_constraints( + self, connection, schema, filter_names, scope, kind, **kw + ): + has_filter_names, params = self._prepare_filter_names(filter_names) + query = self._check_constraint_query( + schema, has_filter_names, scope, kind + ) + result = connection.execute(query, params) + + check_constraints = defaultdict(list) + default = ReflectionDefaults.check_constraints + for table_name, check_name, src in result: + # only two cases for check_name and src: both null or both defined + if check_name is None and src is None: + check_constraints[(schema, table_name)] = default() + continue # samples: # "CHECK (((a > 1) AND (a < 5)))" # "CHECK (((a = 1) OR ((a > 2) AND (a < 5))))" @@ -4424,84 +4258,118 @@ class PGDialect(default.DefaultDialect): sqltext = re.compile( r"^[\s\n]*\((.+)\)[\s\n]*$", flags=re.DOTALL ).sub(r"\1", m.group(1)) - entry = {"name": name, "sqltext": sqltext} + entry = {"name": check_name, "sqltext": sqltext} if m and m.group(2): entry["dialect_options"] = {"not_valid": True} - ret.append(entry) - return ret - - def _load_enums(self, connection, schema=None): - schema = schema or self.default_schema_name - if not self.supports_native_enum: - return {} - - # Load data types for enums: - SQL_ENUMS = """ - SELECT t.typname as "name", - -- no enum defaults in 8.4 at least - -- t.typdefault as "default", - pg_catalog.pg_type_is_visible(t.oid) as "visible", - n.nspname as "schema", - e.enumlabel as "label" - FROM pg_catalog.pg_type t - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace - LEFT JOIN pg_catalog.pg_enum e ON t.oid = e.enumtypid - WHERE t.typtype = 'e' - """ + check_constraints[(schema, table_name)].append(entry) + return check_constraints.items() - if schema != "*": - SQL_ENUMS += "AND n.nspname = :schema " + @lru_cache() + def _enum_query(self, schema): + lbl_sq = ( + select( + pg_catalog.pg_enum.c.enumtypid, pg_catalog.pg_enum.c.enumlabel + ) + .order_by( + pg_catalog.pg_enum.c.enumtypid, + pg_catalog.pg_enum.c.enumsortorder, + ) + .subquery("lbl") + ) - # e.oid gives us label order within an enum - SQL_ENUMS += 'ORDER BY "schema", "name", e.oid' + lbl_agg_sq = ( + select( + lbl_sq.c.enumtypid, + sql.func.array_agg(lbl_sq.c.enumlabel).label("labels"), + ) + .group_by(lbl_sq.c.enumtypid) + .subquery("lbl_agg") + ) - s = sql.text(SQL_ENUMS).columns( - attname=sqltypes.Unicode, label=sqltypes.Unicode + query = ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + "visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + lbl_agg_sq.c.labels.label("labels"), + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .outerjoin( + lbl_agg_sq, pg_catalog.pg_type.c.oid == lbl_agg_sq.c.enumtypid + ) + .where(pg_catalog.pg_type.c.typtype == "e") + .order_by( + pg_catalog.pg_namespace.c.nspname, pg_catalog.pg_type.c.typname + ) ) - if schema != "*": - s = s.bindparams(schema=schema) + if schema is None: + query = query.where( + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid), + # ignore pg_catalog schema + pg_catalog.pg_namespace.c.nspname != "pg_catalog", + ) + elif schema != "*": + query = query.where(pg_catalog.pg_namespace.c.nspname == schema) + return query + + @reflection.cache + 
def _load_enums(self, connection, schema=None, **kw): + if not self.supports_native_enum: + return [] - c = connection.execute(s) + result = connection.execute(self._enum_query(schema)) enums = [] - enum_by_name = {} - for enum in c.fetchall(): - key = (enum.schema, enum.name) - if key in enum_by_name: - enum_by_name[key]["labels"].append(enum.label) - else: - enum_by_name[key] = enum_rec = { - "name": enum.name, - "schema": enum.schema, - "visible": enum.visible, - "labels": [], + for name, visible, schema, labels in result: + enums.append( + { + "name": name, + "schema": schema, + "visible": visible, + "labels": [] if labels is None else labels, } - if enum.label is not None: - enum_rec["labels"].append(enum.label) - enums.append(enum_rec) + ) return enums - def _load_domains(self, connection): - # Load data types for domains: - SQL_DOMAINS = """ - SELECT t.typname as "name", - pg_catalog.format_type(t.typbasetype, t.typtypmod) as "attype", - not t.typnotnull as "nullable", - t.typdefault as "default", - pg_catalog.pg_type_is_visible(t.oid) as "visible", - n.nspname as "schema" - FROM pg_catalog.pg_type t - LEFT JOIN pg_catalog.pg_namespace n ON n.oid = t.typnamespace - WHERE t.typtype = 'd' - """ + @util.memoized_property + def _domain_query(self): + return ( + select( + pg_catalog.pg_type.c.typname.label("name"), + pg_catalog.format_type( + pg_catalog.pg_type.c.typbasetype, + pg_catalog.pg_type.c.typtypmod, + ).label("attype"), + (~pg_catalog.pg_type.c.typnotnull).label("nullable"), + pg_catalog.pg_type.c.typdefault.label("default"), + pg_catalog.pg_type_is_visible(pg_catalog.pg_type.c.oid).label( + "visible" + ), + pg_catalog.pg_namespace.c.nspname.label("schema"), + ) + .join( + pg_catalog.pg_namespace, + pg_catalog.pg_namespace.c.oid + == pg_catalog.pg_type.c.typnamespace, + ) + .where(pg_catalog.pg_type.c.typtype == "d") + ) - s = sql.text(SQL_DOMAINS) - c = connection.execution_options(future_result=True).execute(s) + @reflection.cache + def _load_domains(self, connection, **kw): + # Load data types for domains: + result = connection.execute(self._domain_query) domains = {} - for domain in c.mappings(): + for domain in result.mappings(): domain = domain # strip (30) from character varying(30) attype = re.search(r"([^\(]+)", domain["attype"]).group(1) diff --git a/lib/sqlalchemy/dialects/postgresql/pg8000.py b/lib/sqlalchemy/dialects/postgresql/pg8000.py index 6cb97ece4..ce9a3bb6c 100644 --- a/lib/sqlalchemy/dialects/postgresql/pg8000.py +++ b/lib/sqlalchemy/dialects/postgresql/pg8000.py @@ -107,6 +107,8 @@ from .base import PGIdentifierPreparer from .json import JSON from .json import JSONB from .json import JSONPathType +from .pg_catalog import _SpaceVector +from .pg_catalog import OIDVECTOR from ... import exc from ... 
import util from ...engine import processors @@ -245,6 +247,10 @@ class _PGARRAY(PGARRAY): render_bind_cast = True +class _PGOIDVECTOR(_SpaceVector, OIDVECTOR): + pass + + _server_side_id = util.counter() @@ -376,6 +382,7 @@ class PGDialect_pg8000(PGDialect): sqltypes.BigInteger: _PGBigInteger, sqltypes.Enum: _PGEnum, sqltypes.ARRAY: _PGARRAY, + OIDVECTOR: _PGOIDVECTOR, }, ) diff --git a/lib/sqlalchemy/dialects/postgresql/pg_catalog.py b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py new file mode 100644 index 000000000..a77e7ccf6 --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/pg_catalog.py @@ -0,0 +1,292 @@ +# postgresql/pg_catalog.py +# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +from .array import ARRAY +from .types import OID +from .types import REGCLASS +from ... import Column +from ... import func +from ... import MetaData +from ... import Table +from ...types import BigInteger +from ...types import Boolean +from ...types import CHAR +from ...types import Float +from ...types import Integer +from ...types import SmallInteger +from ...types import String +from ...types import Text +from ...types import TypeDecorator + + +# types +class NAME(TypeDecorator): + impl = String(64, collation="C") + cache_ok = True + + +class PG_NODE_TREE(TypeDecorator): + impl = Text(collation="C") + cache_ok = True + + +class INT2VECTOR(TypeDecorator): + impl = ARRAY(SmallInteger) + cache_ok = True + + +class OIDVECTOR(TypeDecorator): + impl = ARRAY(OID) + cache_ok = True + + +class _SpaceVector: + def result_processor(self, dialect, coltype): + def process(value): + if value is None: + return value + return [int(p) for p in value.split(" ")] + + return process + + +REGPROC = REGCLASS # seems an alias + +# functions +_pg_cat = func.pg_catalog +quote_ident = _pg_cat.quote_ident +pg_table_is_visible = _pg_cat.pg_table_is_visible +pg_type_is_visible = _pg_cat.pg_type_is_visible +pg_get_viewdef = _pg_cat.pg_get_viewdef +pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence +format_type = _pg_cat.format_type +pg_get_expr = _pg_cat.pg_get_expr +pg_get_constraintdef = _pg_cat.pg_get_constraintdef + +# constants +RELKINDS_TABLE_NO_FOREIGN = ("r", "p") +RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",) +RELKINDS_VIEW = ("v",) +RELKINDS_MAT_VIEW = ("m",) +RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW + +# tables +pg_catalog_meta = MetaData() + +pg_namespace = Table( + "pg_namespace", + pg_catalog_meta, + Column("oid", OID), + Column("nspname", NAME), + Column("nspowner", OID), + schema="pg_catalog", +) + +pg_class = Table( + "pg_class", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("relname", NAME), + Column("relnamespace", OID), + Column("reltype", OID), + Column("reloftype", OID), + Column("relowner", OID), + Column("relam", OID), + Column("relfilenode", OID), + Column("reltablespace", OID), + Column("relpages", Integer), + Column("reltuples", Float), + Column("relallvisible", Integer, info={"server_version": (9, 2)}), + Column("reltoastrelid", OID), + Column("relhasindex", Boolean), + Column("relisshared", Boolean), + Column("relpersistence", CHAR, info={"server_version": (9, 1)}), + Column("relkind", CHAR), + Column("relnatts", SmallInteger), + Column("relchecks", SmallInteger), + Column("relhasrules", Boolean), + 
Column("relhastriggers", Boolean), + Column("relhassubclass", Boolean), + Column("relrowsecurity", Boolean), + Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}), + Column("relispopulated", Boolean, info={"server_version": (9, 3)}), + Column("relreplident", CHAR, info={"server_version": (9, 4)}), + Column("relispartition", Boolean, info={"server_version": (10,)}), + Column("relrewrite", OID, info={"server_version": (11,)}), + Column("reloptions", ARRAY(Text)), + schema="pg_catalog", +) + +pg_type = Table( + "pg_type", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("typname", NAME), + Column("typnamespace", OID), + Column("typowner", OID), + Column("typlen", SmallInteger), + Column("typbyval", Boolean), + Column("typtype", CHAR), + Column("typcategory", CHAR), + Column("typispreferred", Boolean), + Column("typisdefined", Boolean), + Column("typdelim", CHAR), + Column("typrelid", OID), + Column("typelem", OID), + Column("typarray", OID), + Column("typinput", REGPROC), + Column("typoutput", REGPROC), + Column("typreceive", REGPROC), + Column("typsend", REGPROC), + Column("typmodin", REGPROC), + Column("typmodout", REGPROC), + Column("typanalyze", REGPROC), + Column("typalign", CHAR), + Column("typstorage", CHAR), + Column("typnotnull", Boolean), + Column("typbasetype", OID), + Column("typtypmod", Integer), + Column("typndims", Integer), + Column("typcollation", OID, info={"server_version": (9, 1)}), + Column("typdefault", Text), + schema="pg_catalog", +) + +pg_index = Table( + "pg_index", + pg_catalog_meta, + Column("indexrelid", OID), + Column("indrelid", OID), + Column("indnatts", SmallInteger), + Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}), + Column("indisunique", Boolean), + Column("indisprimary", Boolean), + Column("indisexclusion", Boolean, info={"server_version": (9, 1)}), + Column("indimmediate", Boolean), + Column("indisclustered", Boolean), + Column("indisvalid", Boolean), + Column("indcheckxmin", Boolean), + Column("indisready", Boolean), + Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3 + Column("indisreplident", Boolean), + Column("indkey", INT2VECTOR), + Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1 + Column("indclass", OIDVECTOR), + Column("indoption", INT2VECTOR), + Column("indexprs", PG_NODE_TREE), + Column("indpred", PG_NODE_TREE), + schema="pg_catalog", +) + +pg_attribute = Table( + "pg_attribute", + pg_catalog_meta, + Column("attrelid", OID), + Column("attname", NAME), + Column("atttypid", OID), + Column("attstattarget", Integer), + Column("attlen", SmallInteger), + Column("attnum", SmallInteger), + Column("attndims", Integer), + Column("attcacheoff", Integer), + Column("atttypmod", Integer), + Column("attbyval", Boolean), + Column("attstorage", CHAR), + Column("attalign", CHAR), + Column("attnotnull", Boolean), + Column("atthasdef", Boolean), + Column("atthasmissing", Boolean, info={"server_version": (11,)}), + Column("attidentity", CHAR, info={"server_version": (10,)}), + Column("attgenerated", CHAR, info={"server_version": (12,)}), + Column("attisdropped", Boolean), + Column("attislocal", Boolean), + Column("attinhcount", Integer), + Column("attcollation", OID, info={"server_version": (9, 1)}), + schema="pg_catalog", +) + +pg_constraint = Table( + "pg_constraint", + pg_catalog_meta, + Column("oid", OID), # 9.3 + Column("conname", NAME), + Column("connamespace", OID), + Column("contype", CHAR), + Column("condeferrable", Boolean), + 
Column("condeferred", Boolean), + Column("convalidated", Boolean, info={"server_version": (9, 1)}), + Column("conrelid", OID), + Column("contypid", OID), + Column("conindid", OID), + Column("conparentid", OID, info={"server_version": (11,)}), + Column("confrelid", OID), + Column("confupdtype", CHAR), + Column("confdeltype", CHAR), + Column("confmatchtype", CHAR), + Column("conislocal", Boolean), + Column("coninhcount", Integer), + Column("connoinherit", Boolean, info={"server_version": (9, 2)}), + Column("conkey", ARRAY(SmallInteger)), + Column("confkey", ARRAY(SmallInteger)), + schema="pg_catalog", +) + +pg_sequence = Table( + "pg_sequence", + pg_catalog_meta, + Column("seqrelid", OID), + Column("seqtypid", OID), + Column("seqstart", BigInteger), + Column("seqincrement", BigInteger), + Column("seqmax", BigInteger), + Column("seqmin", BigInteger), + Column("seqcache", BigInteger), + Column("seqcycle", Boolean), + schema="pg_catalog", + info={"server_version": (10,)}, +) + +pg_attrdef = Table( + "pg_attrdef", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("adrelid", OID), + Column("adnum", SmallInteger), + Column("adbin", PG_NODE_TREE), + schema="pg_catalog", +) + +pg_description = Table( + "pg_description", + pg_catalog_meta, + Column("objoid", OID), + Column("classoid", OID), + Column("objsubid", Integer), + Column("description", Text(collation="C")), + schema="pg_catalog", +) + +pg_enum = Table( + "pg_enum", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("enumtypid", OID), + Column("enumsortorder", Float(), info={"server_version": (9, 1)}), + Column("enumlabel", NAME), + schema="pg_catalog", +) + +pg_am = Table( + "pg_am", + pg_catalog_meta, + Column("oid", OID, info={"server_version": (9, 3)}), + Column("amname", NAME), + Column("amhandler", REGPROC, info={"server_version": (9, 6)}), + Column("amtype", CHAR, info={"server_version": (9, 6)}), + schema="pg_catalog", +) diff --git a/lib/sqlalchemy/dialects/postgresql/types.py b/lib/sqlalchemy/dialects/postgresql/types.py new file mode 100644 index 000000000..55735953b --- /dev/null +++ b/lib/sqlalchemy/dialects/postgresql/types.py @@ -0,0 +1,485 @@ +# Copyright (C) 2013-2022 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + +import datetime as dt +from typing import Any + +from ... import schema +from ... import util +from ...sql import sqltypes +from ...sql.ddl import InvokeDDLBase + + +_DECIMAL_TYPES = (1231, 1700) +_FLOAT_TYPES = (700, 701, 1021, 1022) +_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016) + + +class PGUuid(sqltypes.UUID): + render_bind_cast = True + render_literal_cast = True + + +class BYTEA(sqltypes.LargeBinary[bytes]): + __visit_name__ = "BYTEA" + + +class INET(sqltypes.TypeEngine[str]): + __visit_name__ = "INET" + + +PGInet = INET + + +class CIDR(sqltypes.TypeEngine[str]): + __visit_name__ = "CIDR" + + +PGCidr = CIDR + + +class MACADDR(sqltypes.TypeEngine[str]): + __visit_name__ = "MACADDR" + + +PGMacAddr = MACADDR + + +class MONEY(sqltypes.TypeEngine[str]): + + r"""Provide the PostgreSQL MONEY type. + + Depending on driver, result rows using this type may return a + string value which includes currency symbols. 
+ + For this reason, it may be preferable to provide conversion to a + numerically-based currency datatype using :class:`_types.TypeDecorator`:: + + import re + import decimal + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def process_result_value(self, value: Any, dialect: Any) -> None: + if value is not None: + # adjust this for the currency and numeric + m = re.match(r"\$([\d.]+)", value) + if m: + value = decimal.Decimal(m.group(1)) + return value + + Alternatively, the conversion may be applied as a CAST using + the :meth:`_types.TypeDecorator.column_expression` method as follows:: + + import decimal + from sqlalchemy import cast + from sqlalchemy import TypeDecorator + + class NumericMoney(TypeDecorator): + impl = MONEY + + def column_expression(self, column: Any): + return cast(column, Numeric()) + + .. versionadded:: 1.2 + + """ + + __visit_name__ = "MONEY" + + +class OID(sqltypes.TypeEngine[int]): + + """Provide the PostgreSQL OID type. + + .. versionadded:: 0.9.5 + + """ + + __visit_name__ = "OID" + + +class REGCLASS(sqltypes.TypeEngine[str]): + + """Provide the PostgreSQL REGCLASS type. + + .. versionadded:: 1.2.7 + + """ + + __visit_name__ = "REGCLASS" + + +class TIMESTAMP(sqltypes.TIMESTAMP): + def __init__(self, timezone=False, precision=None): + super(TIMESTAMP, self).__init__(timezone=timezone) + self.precision = precision + + +class TIME(sqltypes.TIME): + def __init__(self, timezone=False, precision=None): + super(TIME, self).__init__(timezone=timezone) + self.precision = precision + + +class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval): + + """PostgreSQL INTERVAL type.""" + + __visit_name__ = "INTERVAL" + native = True + + def __init__(self, precision=None, fields=None): + """Construct an INTERVAL. + + :param precision: optional integer precision value + :param fields: string fields specifier. allows storage of fields + to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``, + etc. + + .. versionadded:: 1.2 + + """ + self.precision = precision + self.fields = fields + + @classmethod + def adapt_emulated_to_native(cls, interval, **kw): + return INTERVAL(precision=interval.second_precision) + + @property + def _type_affinity(self): + return sqltypes.Interval + + def as_generic(self, allow_nulltype=False): + return sqltypes.Interval(native=True, second_precision=self.precision) + + @property + def python_type(self): + return dt.timedelta + + +PGInterval = INTERVAL + + +class BIT(sqltypes.TypeEngine[int]): + __visit_name__ = "BIT" + + def __init__(self, length=None, varying=False): + if not varying: + # BIT without VARYING defaults to length 1 + self.length = length or 1 + else: + # but BIT VARYING can be unlimited-length, so no default + self.length = length + self.varying = varying + + +PGBit = BIT + + +class TSVECTOR(sqltypes.TypeEngine[Any]): + + """The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL + text search type TSVECTOR. + + It can be used to do full text queries on natural language + documents. + + .. versionadded:: 0.9.0 + + .. seealso:: + + :ref:`postgresql_match` + + """ + + __visit_name__ = "TSVECTOR" + + +class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum): + + """PostgreSQL ENUM type. + + This is a subclass of :class:`_types.Enum` which includes + support for PG's ``CREATE TYPE`` and ``DROP TYPE``. 
+ + When the builtin type :class:`_types.Enum` is used and the + :paramref:`.Enum.native_enum` flag is left at its default of + True, the PostgreSQL backend will use a :class:`_postgresql.ENUM` + type as the implementation, so the special create/drop rules + will be used. + + The create/drop behavior of ENUM is necessarily intricate, due to the + awkward relationship the ENUM type has in relationship to the + parent table, in that it may be "owned" by just a single table, or + may be shared among many tables. + + When using :class:`_types.Enum` or :class:`_postgresql.ENUM` + in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted + corresponding to when the :meth:`_schema.Table.create` and + :meth:`_schema.Table.drop` + methods are called:: + + table = Table('sometable', metadata, + Column('some_enum', ENUM('a', 'b', 'c', name='myenum')) + ) + + table.create(engine) # will emit CREATE ENUM and CREATE TABLE + table.drop(engine) # will emit DROP TABLE and DROP ENUM + + To use a common enumerated type between multiple tables, the best + practice is to declare the :class:`_types.Enum` or + :class:`_postgresql.ENUM` independently, and associate it with the + :class:`_schema.MetaData` object itself:: + + my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata) + + t1 = Table('sometable_one', metadata, + Column('some_enum', myenum) + ) + + t2 = Table('sometable_two', metadata, + Column('some_enum', myenum) + ) + + When this pattern is used, care must still be taken at the level + of individual table creates. Emitting CREATE TABLE without also + specifying ``checkfirst=True`` will still cause issues:: + + t1.create(engine) # will fail: no such type 'myenum' + + If we specify ``checkfirst=True``, the individual table-level create + operation will check for the ``ENUM`` and create if not exists:: + + # will check if enum exists, and emit CREATE TYPE if not + t1.create(engine, checkfirst=True) + + When using a metadata-level ENUM type, the type will always be created + and dropped if either the metadata-wide create/drop is called:: + + metadata.create_all(engine) # will emit CREATE TYPE + metadata.drop_all(engine) # will emit DROP TYPE + + The type can also be created and dropped directly:: + + my_enum.create(engine) + my_enum.drop(engine) + + .. versionchanged:: 1.0.0 The PostgreSQL :class:`_postgresql.ENUM` type + now behaves more strictly with regards to CREATE/DROP. A metadata-level + ENUM type will only be created and dropped at the metadata level, + not the table level, with the exception of + ``table.create(checkfirst=True)``. + The ``table.drop()`` call will now emit a DROP TYPE for a table-level + enumerated type. + + """ + + native_enum = True + + def __init__(self, *enums, **kw): + """Construct an :class:`_postgresql.ENUM`. + + Arguments are the same as that of + :class:`_types.Enum`, but also including + the following parameters. + + :param create_type: Defaults to True. + Indicates that ``CREATE TYPE`` should be + emitted, after optionally checking for the + presence of the type, when the parent + table is being created; and additionally + that ``DROP TYPE`` is called when the table + is dropped. When ``False``, no check + will be performed and no ``CREATE TYPE`` + or ``DROP TYPE`` is emitted, unless + :meth:`~.postgresql.ENUM.create` + or :meth:`~.postgresql.ENUM.drop` + are called directly. 
+ Setting to ``False`` is helpful + when invoking a creation scheme to a SQL file + without access to the actual database - + the :meth:`~.postgresql.ENUM.create` and + :meth:`~.postgresql.ENUM.drop` methods can + be used to emit SQL to a target bind. + + """ + native_enum = kw.pop("native_enum", None) + if native_enum is False: + util.warn( + "the native_enum flag does not apply to the " + "sqlalchemy.dialects.postgresql.ENUM datatype; this type " + "always refers to ENUM. Use sqlalchemy.types.Enum for " + "non-native enum." + ) + self.create_type = kw.pop("create_type", True) + super(ENUM, self).__init__(*enums, **kw) + + @classmethod + def adapt_emulated_to_native(cls, impl, **kw): + """Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain + :class:`.Enum`. + + """ + kw.setdefault("validate_strings", impl.validate_strings) + kw.setdefault("name", impl.name) + kw.setdefault("schema", impl.schema) + kw.setdefault("inherit_schema", impl.inherit_schema) + kw.setdefault("metadata", impl.metadata) + kw.setdefault("_create_events", False) + kw.setdefault("values_callable", impl.values_callable) + kw.setdefault("omit_aliases", impl._omit_aliases) + return cls(**kw) + + def create(self, bind=None, checkfirst=True): + """Emit ``CREATE TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL CREATE TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type does not exist already before + creating. + + """ + if not bind.dialect.supports_native_enum: + return + + bind._run_ddl_visitor(self.EnumGenerator, self, checkfirst=checkfirst) + + def drop(self, bind=None, checkfirst=True): + """Emit ``DROP TYPE`` for this + :class:`_postgresql.ENUM`. + + If the underlying dialect does not support + PostgreSQL DROP TYPE, no action is taken. + + :param bind: a connectable :class:`_engine.Engine`, + :class:`_engine.Connection`, or similar object to emit + SQL. + :param checkfirst: if ``True``, a query against + the PG catalog will be first performed to see + if the type actually exists before dropping. 
+ + """ + if not bind.dialect.supports_native_enum: + return + + bind._run_ddl_visitor(self.EnumDropper, self, checkfirst=checkfirst) + + class EnumGenerator(InvokeDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumGenerator, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_create_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return not self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_create_enum(enum): + return + + self.connection.execute(CreateEnumType(enum)) + + class EnumDropper(InvokeDDLBase): + def __init__(self, dialect, connection, checkfirst=False, **kwargs): + super(ENUM.EnumDropper, self).__init__(connection, **kwargs) + self.checkfirst = checkfirst + + def _can_drop_enum(self, enum): + if not self.checkfirst: + return True + + effective_schema = self.connection.schema_for_object(enum) + + return self.connection.dialect.has_type( + self.connection, enum.name, schema=effective_schema + ) + + def visit_enum(self, enum): + if not self._can_drop_enum(enum): + return + + self.connection.execute(DropEnumType(enum)) + + def get_dbapi_type(self, dbapi): + """dont return dbapi.STRING for ENUM in PostgreSQL, since that's + a different type""" + + return None + + def _check_for_name_in_memos(self, checkfirst, kw): + """Look in the 'ddl runner' for 'memos', then + note our name in that collection. + + This to ensure a particular named enum is operated + upon only once within any kind of create/drop + sequence without relying upon "checkfirst". + + """ + if not self.create_type: + return True + if "_ddl_runner" in kw: + ddl_runner = kw["_ddl_runner"] + if "_pg_enums" in ddl_runner.memo: + pg_enums = ddl_runner.memo["_pg_enums"] + else: + pg_enums = ddl_runner.memo["_pg_enums"] = set() + present = (self.schema, self.name) in pg_enums + pg_enums.add((self.schema, self.name)) + return present + else: + return False + + def _on_table_create(self, target, bind, checkfirst=False, **kw): + if ( + checkfirst + or ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + ) + ) and not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_table_drop(self, target, bind, checkfirst=False, **kw): + if ( + not self.metadata + and not kw.get("_is_metadata_operation", False) + and not self._check_for_name_in_memos(checkfirst, kw) + ): + self.drop(bind=bind, checkfirst=checkfirst) + + def _on_metadata_create(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.create(bind=bind, checkfirst=checkfirst) + + def _on_metadata_drop(self, target, bind, checkfirst=False, **kw): + if not self._check_for_name_in_memos(checkfirst, kw): + self.drop(bind=bind, checkfirst=checkfirst) + + +class CreateEnumType(schema._CreateDropBase): + __visit_name__ = "create_enum_type" + + +class DropEnumType(schema._CreateDropBase): + __visit_name__ = "drop_enum_type" diff --git a/lib/sqlalchemy/dialects/sqlite/base.py b/lib/sqlalchemy/dialects/sqlite/base.py index fdcd1340b..22f003e38 100644 --- a/lib/sqlalchemy/dialects/sqlite/base.py +++ b/lib/sqlalchemy/dialects/sqlite/base.py @@ -867,6 +867,7 @@ from ... 
import util from ...engine import default from ...engine import processors from ...engine import reflection +from ...engine.reflection import ReflectionDefaults from ...sql import coercions from ...sql import ColumnElement from ...sql import compiler @@ -2053,28 +2054,27 @@ class SQLiteDialect(default.DefaultDialect): return [db[1] for db in dl if db[1] != "temp"] - @reflection.cache - def get_table_names(self, connection, schema=None, **kw): + def _format_schema(self, schema, table_name): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema + name = f"{qschema}.{table_name}" else: - master = "sqlite_master" - s = ("SELECT name FROM %s " "WHERE type='table' ORDER BY name") % ( - master, - ) - rs = connection.exec_driver_sql(s) - return [row[0] for row in rs] + name = table_name + return name @reflection.cache - def get_temp_table_names(self, connection, **kw): - s = ( - "SELECT name FROM sqlite_temp_master " - "WHERE type='table' ORDER BY name " - ) - rs = connection.exec_driver_sql(s) + def get_table_names(self, connection, schema=None, **kw): + main = self._format_schema(schema, "sqlite_master") + s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names - return [row[0] for row in rs] + @reflection.cache + def get_temp_table_names(self, connection, **kw): + main = "sqlite_temp_master" + s = f"SELECT name FROM {main} WHERE type='table' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names @reflection.cache def get_temp_view_names(self, connection, **kw): @@ -2082,11 +2082,11 @@ class SQLiteDialect(default.DefaultDialect): "SELECT name FROM sqlite_temp_master " "WHERE type='view' ORDER BY name " ) - rs = connection.exec_driver_sql(s) - - return [row[0] for row in rs] + names = connection.exec_driver_sql(s).scalars().all() + return names - def has_table(self, connection, table_name, schema=None): + @reflection.cache + def has_table(self, connection, table_name, schema=None, **kw): self._ensure_has_table_connection(connection) info = self._get_table_pragma( @@ -2099,23 +2099,16 @@ class SQLiteDialect(default.DefaultDialect): @reflection.cache def get_view_names(self, connection, schema=None, **kw): - if schema is not None: - qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema - else: - master = "sqlite_master" - s = ("SELECT name FROM %s " "WHERE type='view' ORDER BY name") % ( - master, - ) - rs = connection.exec_driver_sql(s) - - return [row[0] for row in rs] + main = self._format_schema(schema, "sqlite_master") + s = f"SELECT name FROM {main} WHERE type='view' ORDER BY name" + names = connection.exec_driver_sql(s).scalars().all() + return names @reflection.cache def get_view_definition(self, connection, view_name, schema=None, **kw): if schema is not None: qschema = self.identifier_preparer.quote_identifier(schema) - master = "%s.sqlite_master" % qschema + master = f"{qschema}.sqlite_master" s = ("SELECT sql FROM %s WHERE name = ? 
AND type='view'") % ( master, ) @@ -2140,6 +2133,10 @@ class SQLiteDialect(default.DefaultDialect): result = rs.fetchall() if result: return result[0].sql + else: + raise exc.NoSuchTableError( + f"{schema}.{view_name}" if schema else view_name + ) @reflection.cache def get_columns(self, connection, table_name, schema=None, **kw): @@ -2186,7 +2183,14 @@ class SQLiteDialect(default.DefaultDialect): tablesql, ) ) - return columns + if columns: + return columns + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.columns() def _get_column_info( self, @@ -2216,7 +2220,6 @@ class SQLiteDialect(default.DefaultDialect): "type": coltype, "nullable": nullable, "default": default, - "autoincrement": "auto", "primary_key": primary_key, } if generated: @@ -2295,13 +2298,16 @@ class SQLiteDialect(default.DefaultDialect): constraint_name = result.group(1) if result else None cols = self.get_columns(connection, table_name, schema, **kw) + # consider only pk columns. This also avoids sorting the cached + # value returned by get_columns + cols = [col for col in cols if col.get("primary_key", 0) > 0] cols.sort(key=lambda col: col.get("primary_key")) - pkeys = [] - for col in cols: - if col["primary_key"]: - pkeys.append(col["name"]) + pkeys = [col["name"] for col in cols] - return {"constrained_columns": pkeys, "name": constraint_name} + if pkeys: + return {"constrained_columns": pkeys, "name": constraint_name} + else: + return ReflectionDefaults.pk_constraint() @reflection.cache def get_foreign_keys(self, connection, table_name, schema=None, **kw): @@ -2321,12 +2327,14 @@ class SQLiteDialect(default.DefaultDialect): # original DDL. The referred columns of the foreign key # constraint are therefore the primary key of the referred # table. - referred_pk = self.get_pk_constraint( - connection, rtbl, schema=schema, **kw - ) - # note that if table doesn't exist, we still get back a record, - # just it has no columns in it - referred_columns = referred_pk["constrained_columns"] + try: + referred_pk = self.get_pk_constraint( + connection, rtbl, schema=schema, **kw + ) + referred_columns = referred_pk["constrained_columns"] + except exc.NoSuchTableError: + # ignore not existing parents + referred_columns = [] else: # note we use this list only if this is the first column # in the constraint. for subsequent columns we ignore the @@ -2378,11 +2386,11 @@ class SQLiteDialect(default.DefaultDialect): ) table_data = self._get_table_sql(connection, table_name, schema=schema) - if table_data is None: - # system tables, etc. - return [] def parse_fks(): + if table_data is None: + # system tables, etc. + return FK_PATTERN = ( r"(?:CONSTRAINT (\w+) +)?" r"FOREIGN KEY *\( *(.+?) *\) +" @@ -2453,7 +2461,10 @@ class SQLiteDialect(default.DefaultDialect): # use them as is as it's extremely difficult to parse inline # constraints fkeys.extend(keys_by_signature.values()) - return fkeys + if fkeys: + return fkeys + else: + return ReflectionDefaults.foreign_keys() def _find_cols_in_sig(self, sig): for match in re.finditer(r'(?:"(.+?)")|([a-z0-9_]+)', sig, re.I): @@ -2480,12 +2491,11 @@ class SQLiteDialect(default.DefaultDialect): table_data = self._get_table_sql( connection, table_name, schema=schema, **kw ) - if not table_data: - return [] - unique_constraints = [] def parse_uqs(): + if table_data is None: + return UNIQUE_PATTERN = r'(?:CONSTRAINT "?(.+?)"? 
+)?UNIQUE *\((.+?)\)' INLINE_UNIQUE_PATTERN = ( r'(?:(".+?")|(?:[\[`])?([a-z0-9_]+)(?:[\]`])?) ' @@ -2513,15 +2523,16 @@ class SQLiteDialect(default.DefaultDialect): unique_constraints.append(parsed_constraint) # NOTE: auto_index_by_sig might not be empty here, # the PRIMARY KEY may have an entry. - return unique_constraints + if unique_constraints: + return unique_constraints + else: + return ReflectionDefaults.unique_constraints() @reflection.cache def get_check_constraints(self, connection, table_name, schema=None, **kw): table_data = self._get_table_sql( connection, table_name, schema=schema, **kw ) - if not table_data: - return [] CHECK_PATTERN = r"(?:CONSTRAINT (.+) +)?" r"CHECK *\( *(.+) *\),? *" check_constraints = [] @@ -2531,7 +2542,7 @@ class SQLiteDialect(default.DefaultDialect): # necessarily makes assumptions as to how the CREATE TABLE # was emitted. - for match in re.finditer(CHECK_PATTERN, table_data, re.I): + for match in re.finditer(CHECK_PATTERN, table_data or "", re.I): name = match.group(1) if name: @@ -2539,7 +2550,10 @@ class SQLiteDialect(default.DefaultDialect): check_constraints.append({"sqltext": match.group(2), "name": name}) - return check_constraints + if check_constraints: + return check_constraints + else: + return ReflectionDefaults.check_constraints() @reflection.cache def get_indexes(self, connection, table_name, schema=None, **kw): @@ -2561,7 +2575,7 @@ class SQLiteDialect(default.DefaultDialect): # loop thru unique indexes to get the column names. for idx in list(indexes): pragma_index = self._get_table_pragma( - connection, "index_info", idx["name"] + connection, "index_info", idx["name"], schema=schema ) for row in pragma_index: @@ -2574,7 +2588,23 @@ class SQLiteDialect(default.DefaultDialect): break else: idx["column_names"].append(row[2]) - return indexes + indexes.sort(key=lambda d: d["name"] or "~") # sort None as last + if indexes: + return indexes + elif not self.has_table(connection, table_name, schema): + raise exc.NoSuchTableError( + f"{schema}.{table_name}" if schema else table_name + ) + else: + return ReflectionDefaults.indexes() + + def _is_sys_table(self, table_name): + return table_name in { + "sqlite_schema", + "sqlite_master", + "sqlite_temp_schema", + "sqlite_temp_master", + } @reflection.cache def _get_table_sql(self, connection, table_name, schema=None, **kw): @@ -2590,22 +2620,25 @@ class SQLiteDialect(default.DefaultDialect): " (SELECT * FROM %(schema)ssqlite_master UNION ALL " " SELECT * FROM %(schema)ssqlite_temp_master) " "WHERE name = ? " - "AND type = 'table'" % {"schema": schema_expr} + "AND type in ('table', 'view')" % {"schema": schema_expr} ) rs = connection.exec_driver_sql(s, (table_name,)) except exc.DBAPIError: s = ( "SELECT sql FROM %(schema)ssqlite_master " "WHERE name = ? " - "AND type = 'table'" % {"schema": schema_expr} + "AND type in ('table', 'view')" % {"schema": schema_expr} ) rs = connection.exec_driver_sql(s, (table_name,)) - return rs.scalar() + value = rs.scalar() + if value is None and not self._is_sys_table(table_name): + raise exc.NoSuchTableError(f"{schema_expr}{table_name}") + return value def _get_table_pragma(self, connection, pragma, table_name, schema=None): quote = self.identifier_preparer.quote_identifier if schema is not None: - statements = ["PRAGMA %s." 
% quote(schema)] + statements = [f"PRAGMA {quote(schema)}."] else: # because PRAGMA looks in all attached databases if no schema # given, need to specify "main" schema, however since we want @@ -2615,7 +2648,7 @@ class SQLiteDialect(default.DefaultDialect): qtable = quote(table_name) for statement in statements: - statement = "%s%s(%s)" % (statement, pragma, qtable) + statement = f"{statement}{pragma}({qtable})" cursor = connection.exec_driver_sql(statement) if not cursor._soft_closed: # work around SQLite issue whereby cursor.description diff --git a/lib/sqlalchemy/engine/__init__.py b/lib/sqlalchemy/engine/__init__.py index afba17075..77c2fea40 100644 --- a/lib/sqlalchemy/engine/__init__.py +++ b/lib/sqlalchemy/engine/__init__.py @@ -38,6 +38,8 @@ from .interfaces import ExecutionContext as ExecutionContext from .interfaces import TypeCompiler as TypeCompiler from .mock import create_mock_engine as create_mock_engine from .reflection import Inspector as Inspector +from .reflection import ObjectKind as ObjectKind +from .reflection import ObjectScope as ObjectScope from .result import ChunkedIteratorResult as ChunkedIteratorResult from .result import FrozenResult as FrozenResult from .result import IteratorResult as IteratorResult diff --git a/lib/sqlalchemy/engine/default.py b/lib/sqlalchemy/engine/default.py index df35e7128..40af06252 100644 --- a/lib/sqlalchemy/engine/default.py +++ b/lib/sqlalchemy/engine/default.py @@ -45,6 +45,8 @@ from .interfaces import CacheStats from .interfaces import DBAPICursor from .interfaces import Dialect from .interfaces import ExecutionContext +from .reflection import ObjectKind +from .reflection import ObjectScope from .. import event from .. import exc from .. import pool @@ -508,15 +510,22 @@ class DefaultDialect(Dialect): """ return type_api.adapt_type(typeobj, self.colspecs) - def has_index(self, connection, table_name, index_name, schema=None): - if not self.has_table(connection, table_name, schema=schema): + def has_index(self, connection, table_name, index_name, schema=None, **kw): + if not self.has_table(connection, table_name, schema=schema, **kw): return False - for idx in self.get_indexes(connection, table_name, schema=schema): + for idx in self.get_indexes( + connection, table_name, schema=schema, **kw + ): if idx["name"] == index_name: return True else: return False + def has_schema( + self, connection: Connection, schema_name: str, **kw: Any + ) -> bool: + return schema_name in self.get_schema_names(connection, **kw) + def validate_identifier(self, ident): if len(ident) > self.max_identifier_length: raise exc.IdentifierError( @@ -769,6 +778,122 @@ class DefaultDialect(Dialect): def get_driver_connection(self, connection): return connection + def _overrides_default(self, method): + return ( + getattr(type(self), method).__code__ + is not getattr(DefaultDialect, method).__code__ + ) + + def _default_multi_reflect( + self, + single_tbl_method, + connection, + kind, + schema, + filter_names, + scope, + **kw, + ): + + names_fns = [] + temp_names_fns = [] + if ObjectKind.TABLE in kind: + names_fns.append(self.get_table_names) + temp_names_fns.append(self.get_temp_table_names) + if ObjectKind.VIEW in kind: + names_fns.append(self.get_view_names) + temp_names_fns.append(self.get_temp_view_names) + if ObjectKind.MATERIALIZED_VIEW in kind: + names_fns.append(self.get_materialized_view_names) + # no temp materialized view at the moment + # temp_names_fns.append(self.get_temp_materialized_view_names) + + unreflectable = kw.pop("unreflectable", {}) + + 
if ( + filter_names + and scope is ObjectScope.ANY + and kind is ObjectKind.ANY + ): + # if names are given and no qualification on type of table + # (i.e. the Table(..., autoload) case), take the names as given, + # don't run names queries. If a table does not exit + # NoSuchTableError is raised and it's skipped + + # this also suits the case for mssql where we can reflect + # individual temp tables but there's no temp_names_fn + names = filter_names + else: + names = [] + name_kw = {"schema": schema, **kw} + fns = [] + if ObjectScope.DEFAULT in scope: + fns.extend(names_fns) + if ObjectScope.TEMPORARY in scope: + fns.extend(temp_names_fns) + + for fn in fns: + try: + names.extend(fn(connection, **name_kw)) + except NotImplementedError: + pass + + if filter_names: + filter_names = set(filter_names) + + # iterate over all the tables/views and call the single table method + for table in names: + if not filter_names or table in filter_names: + key = (schema, table) + try: + yield ( + key, + single_tbl_method( + connection, table, schema=schema, **kw + ), + ) + except exc.UnreflectableTableError as err: + if key not in unreflectable: + unreflectable[key] = err + except exc.NoSuchTableError: + pass + + def get_multi_table_options(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_options, connection, **kw + ) + + def get_multi_columns(self, connection, **kw): + return self._default_multi_reflect(self.get_columns, connection, **kw) + + def get_multi_pk_constraint(self, connection, **kw): + return self._default_multi_reflect( + self.get_pk_constraint, connection, **kw + ) + + def get_multi_foreign_keys(self, connection, **kw): + return self._default_multi_reflect( + self.get_foreign_keys, connection, **kw + ) + + def get_multi_indexes(self, connection, **kw): + return self._default_multi_reflect(self.get_indexes, connection, **kw) + + def get_multi_unique_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_unique_constraints, connection, **kw + ) + + def get_multi_check_constraints(self, connection, **kw): + return self._default_multi_reflect( + self.get_check_constraints, connection, **kw + ) + + def get_multi_table_comment(self, connection, **kw): + return self._default_multi_reflect( + self.get_table_comment, connection, **kw + ) + class StrCompileDialect(DefaultDialect): diff --git a/lib/sqlalchemy/engine/interfaces.py b/lib/sqlalchemy/engine/interfaces.py index b8e85b646..28ed03f99 100644 --- a/lib/sqlalchemy/engine/interfaces.py +++ b/lib/sqlalchemy/engine/interfaces.py @@ -15,7 +15,9 @@ from typing import Any from typing import Awaitable from typing import Callable from typing import ClassVar +from typing import Collection from typing import Dict +from typing import Iterable from typing import List from typing import Mapping from typing import MutableMapping @@ -324,7 +326,7 @@ class ReflectedColumn(TypedDict): nullable: bool """column nullability""" - default: str + default: Optional[str] """column default expression as a SQL string""" autoincrement: NotRequired[bool] @@ -343,11 +345,11 @@ class ReflectedColumn(TypedDict): comment: NotRequired[Optional[str]] """comment for the column, if present""" - computed: NotRequired[Optional[ReflectedComputed]] + computed: NotRequired[ReflectedComputed] """indicates this column is computed at insert (possibly update) time by the database.""" - identity: NotRequired[Optional[ReflectedIdentity]] + identity: NotRequired[ReflectedIdentity] """indicates this column is an IDENTITY column""" 
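# --- illustrative sketch (not part of the patch) -------------------------
# The ReflectedColumn typing changes above mean "default" may be None and
# "computed"/"identity" appear only when the column actually uses them.
# A minimal consumer of Inspector.get_columns(); the in-memory SQLite table
# is purely a stand-in.
from sqlalchemy import create_engine, inspect, types

engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.exec_driver_sql(
        "CREATE TABLE some_table (id INTEGER PRIMARY KEY, note TEXT DEFAULT 'n/a')"
    )

insp = inspect(engine)
for col in insp.get_columns("some_table"):
    # each entry is a ReflectedColumn dict; "type" is always a TypeEngine
    # instance, "default" is either SQL text such as "'n/a'" or None, and
    # optional keys like "computed"/"identity" are simply absent when unused
    assert isinstance(col["type"], types.TypeEngine)
    print(col["name"], col["nullable"], col["default"])
# -------------------------------------------------------------------------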
dialect_options: NotRequired[Dict[str, Any]] @@ -390,6 +392,9 @@ class ReflectedUniqueConstraint(TypedDict): column_names: List[str] """column names which comprise the constraint""" + duplicates_index: NotRequired[Optional[str]] + "Indicates if this unique constraint duplicates an index with this name" + dialect_options: NotRequired[Dict[str, Any]] """Additional dialect-specific options detected for this reflected object""" @@ -439,7 +444,7 @@ class ReflectedForeignKeyConstraint(TypedDict): referred_columns: List[str] """referenced column names""" - dialect_options: NotRequired[Dict[str, Any]] + options: NotRequired[Dict[str, Any]] """Additional dialect-specific options detected for this reflected object""" @@ -462,9 +467,8 @@ class ReflectedIndex(TypedDict): unique: bool """whether or not the index has a unique flag""" - duplicates_constraint: NotRequired[bool] - """boolean indicating this index mirrors a unique constraint of the same - name""" + duplicates_constraint: NotRequired[Optional[str]] + "Indicates if this index mirrors a unique constraint with this name" include_columns: NotRequired[List[str]] """columns to include in the INCLUDE clause for supporting databases. @@ -472,7 +476,7 @@ class ReflectedIndex(TypedDict): .. deprecated:: 2.0 Legacy value, will be replaced with - ``d["dialect_options"][<dialect name>]["include"]`` + ``d["dialect_options"]["<dialect name>_include"]`` """ @@ -494,7 +498,7 @@ class ReflectedTableComment(TypedDict): """ - text: str + text: Optional[str] """text of the comment""" @@ -547,6 +551,7 @@ class BindTyping(Enum): VersionInfoType = Tuple[Union[int, str], ...] +TableKey = Tuple[Optional[str], str] class Dialect(EventTarget): @@ -1040,7 +1045,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def initialize(self, connection: "Connection") -> None: + def initialize(self, connection: Connection) -> None: """Called during strategized creation of the dialect with a connection. @@ -1060,9 +1065,14 @@ class Dialect(EventTarget): pass + if TYPE_CHECKING: + + def _overrides_default(self, method_name: str) -> bool: + ... + def get_columns( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1074,13 +1084,40 @@ class Dialect(EventTarget): information as a list of dictionaries corresponding to the :class:`.ReflectedColumn` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_columns`. + """ + + def get_multi_columns( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedColumn]]]: + """Return information about columns in all tables in the + given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_columns`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. 
versionadded:: 2.0 + """ raise NotImplementedError() def get_pk_constraint( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1093,13 +1130,41 @@ class Dialect(EventTarget): key information as a dictionary corresponding to the :class:`.ReflectedPrimaryKeyConstraint` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_pk_constraint`. + + """ + raise NotImplementedError() + + def get_multi_pk_constraint( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, ReflectedPrimaryKeyConstraint]]: + """Return information about primary key constraints in + all tables in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_pk_constraint`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 """ raise NotImplementedError() def get_foreign_keys( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1111,42 +1176,104 @@ class Dialect(EventTarget): key information as a list of dicts corresponding to the :class:`.ReflectedForeignKeyConstraint` dictionary. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_foreign_keys`. + """ + + raise NotImplementedError() + + def get_multi_foreign_keys( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedForeignKeyConstraint]]]: + """Return information about foreign_keys in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_foreign_keys`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_table_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: - """Return a list of table names for ``schema``.""" + """Return a list of table names for ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_table_names`. + + """ raise NotImplementedError() def get_temp_table_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: """Return a list of temporary table names on the given connection, if supported by the underlying backend. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_temp_table_names`. 
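# --- illustrative sketch (not part of the patch) -------------------------
# Rough shape of a third-party dialect opting in to batched reflection, per
# the contract documented above.  The default implementation inherited from
# DefaultDialect loops over get_table_names()/get_view_names() and calls the
# single-table method once per name; an override can instead issue one
# catalog query and yield ((schema, table_name), columns) pairs.
# "ExampleDialect" and the information_schema-style catalog are assumptions;
# the per-table get_columns() (not shown) remains required for compatibility.
from sqlalchemy import types as sqltypes
from sqlalchemy.engine.default import DefaultDialect

_SIMPLE_TYPES = {"integer": sqltypes.Integer, "text": sqltypes.Text}


class ExampleDialect(DefaultDialect):
    """Hypothetical third-party dialect."""

    name = "exampledb"

    def get_multi_columns(
        self, connection, schema=None, filter_names=None, **kw
    ):
        # one catalog query instead of one query per table
        # (schema filtering omitted for brevity)
        rows = connection.exec_driver_sql(
            "SELECT table_name, column_name, data_type, is_nullable "
            "FROM information_schema.columns"
        )
        by_table = {}
        for table_name, col_name, data_type, is_nullable in rows:
            if filter_names is not None and table_name not in filter_names:
                continue
            by_table.setdefault((schema, table_name), []).append(
                {
                    "name": col_name,
                    "type": _SIMPLE_TYPES.get(data_type.lower(), sqltypes.NullType)(),
                    "nullable": is_nullable == "YES",
                    "default": None,
                }
            )
        yield from by_table.items()
# -------------------------------------------------------------------------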
+ """ raise NotImplementedError() def get_view_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + """Return a list of all non-materialized view names available in the + database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_view_names`. + + :param schema: schema name to query, if not the default schema. + + """ + + raise NotImplementedError() + + def get_materialized_view_names( + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: - """Return a list of all view names available in the database. + """Return a list of all materialized view names available in the + database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_materialized_view_names`. :param schema: schema name to query, if not the default schema. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_sequence_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: """Return a list of all sequence names available in the database. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_sequence_names`. + :param schema: schema name to query, if not the default schema. .. versionadded:: 1.4 @@ -1155,26 +1282,40 @@ class Dialect(EventTarget): raise NotImplementedError() def get_temp_view_names( - self, connection: "Connection", schema: Optional[str] = None, **kw: Any + self, connection: Connection, schema: Optional[str] = None, **kw: Any ) -> List[str]: """Return a list of temporary view names on the given connection, if supported by the underlying backend. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_temp_view_names`. + """ raise NotImplementedError() + def get_schema_names(self, connection: Connection, **kw: Any) -> List[str]: + """Return a list of all schema names available in the database. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_schema_names`. + """ + raise NotImplementedError() + def get_view_definition( self, - connection: "Connection", + connection: Connection, view_name: str, schema: Optional[str] = None, **kw: Any, ) -> str: - """Return view definition. + """Return plain or materialized view definition. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_view_definition`. Given a :class:`_engine.Connection`, a string - `view_name`, and an optional string ``schema``, return the view + ``view_name``, and an optional string ``schema``, return the view definition. """ @@ -1182,7 +1323,7 @@ class Dialect(EventTarget): def get_indexes( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1194,13 +1335,42 @@ class Dialect(EventTarget): information as a list of dictionaries corresponding to the :class:`.ReflectedIndex` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_indexes`. 
+ """ + + raise NotImplementedError() + + def get_multi_indexes( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedIndex]]]: + """Return information about indexes in in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_indexes`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_unique_constraints( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1211,13 +1381,42 @@ class Dialect(EventTarget): unique constraint information as a list of dicts corresponding to the :class:`.ReflectedUniqueConstraint` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_unique_constraints`. + """ + + raise NotImplementedError() + + def get_multi_unique_constraints( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedUniqueConstraint]]]: + """Return information about unique constraints in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_unique_constraints`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_check_constraints( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1228,26 +1427,86 @@ class Dialect(EventTarget): check constraint information as a list of dicts corresponding to the :class:`.ReflectedCheckConstraint` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_check_constraints`. + + .. versionadded:: 1.1.0 + + """ + + raise NotImplementedError() + + def get_multi_check_constraints( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, List[ReflectedCheckConstraint]]]: + """Return information about check constraints in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_multi_check_constraints`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. 
Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + """ raise NotImplementedError() def get_table_options( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, - ) -> Optional[Dict[str, Any]]: - r"""Return the "options" for the table identified by ``table_name`` - as a dictionary. + ) -> Dict[str, Any]: + """Return a dictionary of options specified when ``table_name`` + was created. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_table_options`. """ - return None + raise NotImplementedError() + + def get_multi_table_options( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, Dict[str, Any]]]: + """Return a dictionary of options specified when the tables in the + given schema were created. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_table_options`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + raise NotImplementedError() def get_table_comment( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1258,6 +1517,8 @@ class Dialect(EventTarget): table comment information as a dictionary corresponding to the :class:`.ReflectedTableComment` dictionary. + This is an internal dialect method. Applications should use + :meth:`.Inspector.get_table_comment`. :raise: ``NotImplementedError`` for dialects that don't support comments. @@ -1268,6 +1529,33 @@ class Dialect(EventTarget): raise NotImplementedError() + def get_multi_table_comment( + self, + connection: Connection, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + **kw: Any, + ) -> Iterable[Tuple[TableKey, ReflectedTableComment]]: + """Return information about the table comment in all tables + in the given ``schema``. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.get_multi_table_comment`. + + .. note:: The :class:`_engine.DefaultDialect` provides a default + implementation that will call the single table method for + each object returned by :meth:`Dialect.get_table_names`, + :meth:`Dialect.get_view_names` or + :meth:`Dialect.get_materialized_view_names` depending on the + provided ``kind``. Dialects that want to support a faster + implementation should implement this method. + + .. versionadded:: 2.0 + + """ + + raise NotImplementedError() + def normalize_name(self, name: str) -> str: """convert the given name to lowercase if it is detected as case insensitive. 
@@ -1290,7 +1578,7 @@ class Dialect(EventTarget): def has_table( self, - connection: "Connection", + connection: Connection, table_name: str, schema: Optional[str] = None, **kw: Any, @@ -1327,21 +1615,24 @@ class Dialect(EventTarget): def has_index( self, - connection: "Connection", + connection: Connection, table_name: str, index_name: str, schema: Optional[str] = None, + **kw: Any, ) -> bool: """Check the existence of a particular index name in the database. Given a :class:`_engine.Connection` object, a string - ``table_name`` and string index name, return True if an index of the - given name on the given table exists, false otherwise. + ``table_name`` and string index name, return ``True`` if an index of + the given name on the given table exists, ``False`` otherwise. The :class:`.DefaultDialect` implements this in terms of the :meth:`.Dialect.has_table` and :meth:`.Dialect.get_indexes` methods, however dialects can implement a more performant version. + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_index`. .. versionadded:: 1.4 @@ -1351,7 +1642,7 @@ class Dialect(EventTarget): def has_sequence( self, - connection: "Connection", + connection: Connection, sequence_name: str, schema: Optional[str] = None, **kw: Any, @@ -1359,13 +1650,39 @@ class Dialect(EventTarget): """Check the existence of a particular sequence in the database. Given a :class:`_engine.Connection` object and a string - `sequence_name`, return True if the given sequence exists in - the database, False otherwise. + `sequence_name`, return ``True`` if the given sequence exists in + the database, ``False`` otherwise. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_sequence`. + """ + + raise NotImplementedError() + + def has_schema( + self, connection: Connection, schema_name: str, **kw: Any + ) -> bool: + """Check the existence of a particular schema name in the database. + + Given a :class:`_engine.Connection` object, a string + ``schema_name``, return ``True`` if a schema of the + given exists, ``False`` otherwise. + + The :class:`.DefaultDialect` implements this by checking + the presence of ``schema_name`` among the schemas returned by + :meth:`.Dialect.get_schema_names`, + however dialects can implement a more performant version. + + This is an internal dialect method. Applications should use + :meth:`_engine.Inspector.has_schema`. + + .. versionadded:: 2.0 + """ raise NotImplementedError() - def _get_server_version_info(self, connection: "Connection") -> Any: + def _get_server_version_info(self, connection: Connection) -> Any: """Retrieve the server version info from the given connection. This is used by the default implementation to populate the @@ -1376,7 +1693,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def _get_default_schema_name(self, connection: "Connection") -> str: + def _get_default_schema_name(self, connection: Connection) -> str: """Return the string name of the currently selected schema from the given connection. @@ -1481,7 +1798,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_savepoint(self, connection: "Connection", name: str) -> None: + def do_savepoint(self, connection: Connection, name: str) -> None: """Create a savepoint with the given name. :param connection: a :class:`_engine.Connection`. 
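# --- illustrative sketch (not part of the patch) -------------------------
# The existence checks these dialect hooks back, called via the Inspector.
# The DefaultDialect fallbacks shown earlier implement has_schema() through
# get_schema_names() and has_index() through has_table()/get_indexes().
# URL, schema, table and index names are placeholders.
from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
insp = inspect(engine)

if insp.has_schema("test_schema"):
    print("schema is present")

if insp.has_table("user_account", schema="test_schema"):
    print(insp.has_index("user_account", "ix_user_name", schema="test_schema"))
# -------------------------------------------------------------------------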
@@ -1492,7 +1809,7 @@ class Dialect(EventTarget): raise NotImplementedError() def do_rollback_to_savepoint( - self, connection: "Connection", name: str + self, connection: Connection, name: str ) -> None: """Rollback a connection to the named savepoint. @@ -1503,9 +1820,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_release_savepoint( - self, connection: "Connection", name: str - ) -> None: + def do_release_savepoint(self, connection: Connection, name: str) -> None: """Release the named savepoint on a connection. :param connection: a :class:`_engine.Connection`. @@ -1514,7 +1829,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_begin_twophase(self, connection: "Connection", xid: Any) -> None: + def do_begin_twophase(self, connection: Connection, xid: Any) -> None: """Begin a two phase transaction on the given connection. :param connection: a :class:`_engine.Connection`. @@ -1524,7 +1839,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_prepare_twophase(self, connection: "Connection", xid: Any) -> None: + def do_prepare_twophase(self, connection: Connection, xid: Any) -> None: """Prepare a two phase transaction on the given connection. :param connection: a :class:`_engine.Connection`. @@ -1536,7 +1851,7 @@ class Dialect(EventTarget): def do_rollback_twophase( self, - connection: "Connection", + connection: Connection, xid: Any, is_prepared: bool = True, recover: bool = False, @@ -1555,7 +1870,7 @@ class Dialect(EventTarget): def do_commit_twophase( self, - connection: "Connection", + connection: Connection, xid: Any, is_prepared: bool = True, recover: bool = False, @@ -1573,7 +1888,7 @@ class Dialect(EventTarget): raise NotImplementedError() - def do_recover_twophase(self, connection: "Connection") -> List[Any]: + def do_recover_twophase(self, connection: Connection) -> List[Any]: """Recover list of uncommitted prepared two phase transaction identifiers on the given connection. diff --git a/lib/sqlalchemy/engine/reflection.py b/lib/sqlalchemy/engine/reflection.py index 4fc57d5f4..32c89106b 100644 --- a/lib/sqlalchemy/engine/reflection.py +++ b/lib/sqlalchemy/engine/reflection.py @@ -27,39 +27,148 @@ methods such as get_table_names, get_columns, etc. from __future__ import annotations import contextlib +from dataclasses import dataclass +from enum import auto +from enum import Flag +from enum import unique +from typing import Any +from typing import Callable +from typing import Collection +from typing import Dict +from typing import Generator +from typing import Iterable from typing import List from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union from .base import Connection from .base import Engine -from .interfaces import ReflectedColumn from .. import exc from .. import inspection from .. import sql from .. 
import util from ..sql import operators from ..sql import schema as sa_schema +from ..sql.cache_key import _ad_hoc_cache_key_from_args +from ..sql.elements import TextClause from ..sql.type_api import TypeEngine +from ..sql.visitors import InternalTraversal from ..util import topological +from ..util.typing import final + +if TYPE_CHECKING: + from .interfaces import Dialect + from .interfaces import ReflectedCheckConstraint + from .interfaces import ReflectedColumn + from .interfaces import ReflectedForeignKeyConstraint + from .interfaces import ReflectedIndex + from .interfaces import ReflectedPrimaryKeyConstraint + from .interfaces import ReflectedTableComment + from .interfaces import ReflectedUniqueConstraint + from .interfaces import TableKey + +_R = TypeVar("_R") @util.decorator -def cache(fn, self, con, *args, **kw): +def cache( + fn: Callable[..., _R], + self: Dialect, + con: Connection, + *args: Any, + **kw: Any, +) -> _R: info_cache = kw.get("info_cache", None) if info_cache is None: return fn(self, con, *args, **kw) + exclude = {"info_cache", "unreflectable"} key = ( fn.__name__, tuple(a for a in args if isinstance(a, str)), - tuple((k, v) for k, v in kw.items() if k != "info_cache"), + tuple((k, v) for k, v in kw.items() if k not in exclude), ) - ret = info_cache.get(key) + ret: _R = info_cache.get(key) if ret is None: ret = fn(self, con, *args, **kw) info_cache[key] = ret return ret +def flexi_cache( + *traverse_args: Tuple[str, InternalTraversal] +) -> Callable[[Callable[..., _R]], Callable[..., _R]]: + @util.decorator + def go( + fn: Callable[..., _R], + self: Dialect, + con: Connection, + *args: Any, + **kw: Any, + ) -> _R: + info_cache = kw.get("info_cache", None) + if info_cache is None: + return fn(self, con, *args, **kw) + key = _ad_hoc_cache_key_from_args((fn.__name__,), traverse_args, args) + ret: _R = info_cache.get(key) + if ret is None: + ret = fn(self, con, *args, **kw) + info_cache[key] = ret + return ret + + return go + + +@unique +class ObjectKind(Flag): + """Enumerator that indicates which kind of object to return when calling + the ``get_multi`` methods. + + This is a Flag enum, so custom combinations can be passed. For example, + to reflect tables and plain views ``ObjectKind.TABLE | ObjectKind.VIEW`` + may be used. + + .. note:: + Not all dialect may support all kind of object. If a dialect does + not support a particular object an empty dict is returned. + In case a dialect supports an object, but the requested method + is not applicable for the specified kind the default value + will be returned for each reflected object. For example reflecting + check constraints of view return a dict with all the views with + empty lists as values. + """ + + TABLE = auto() + "Reflect table objects" + VIEW = auto() + "Reflect plain view objects" + MATERIALIZED_VIEW = auto() + "Reflect materialized view object" + + ANY_VIEW = VIEW | MATERIALIZED_VIEW + "Reflect any kind of view objects" + ANY = TABLE | VIEW | MATERIALIZED_VIEW + "Reflect all type of objects" + + +@unique +class ObjectScope(Flag): + """Enumerator that indicates which scope to use when calling + the ``get_multi`` methods. + """ + + DEFAULT = auto() + "Include default scope" + TEMPORARY = auto() + "Include only temp scope" + ANY = DEFAULT | TEMPORARY + "Include both default and temp scope" + + @inspection._self_inspects class Inspector(inspection.Inspectable["Inspector"]): """Performs database schema inspection. 
@@ -85,6 +194,12 @@ class Inspector(inspection.Inspectable["Inspector"]): """ + bind: Union[Engine, Connection] + engine: Engine + _op_context_requires_connect: bool + dialect: Dialect + info_cache: Dict[Any, Any] + @util.deprecated( "1.4", "The __init__() method on :class:`_reflection.Inspector` " @@ -96,7 +211,7 @@ class Inspector(inspection.Inspectable["Inspector"]): "in order to " "acquire an :class:`_reflection.Inspector`.", ) - def __init__(self, bind): + def __init__(self, bind: Union[Engine, Connection]): """Initialize a new :class:`_reflection.Inspector`. :param bind: a :class:`~sqlalchemy.engine.Connection`, @@ -108,38 +223,51 @@ class Inspector(inspection.Inspectable["Inspector"]): :meth:`_reflection.Inspector.from_engine` """ - return self._init_legacy(bind) + self._init_legacy(bind) @classmethod - def _construct(cls, init, bind): + def _construct( + cls, init: Callable[..., Any], bind: Union[Engine, Connection] + ) -> Inspector: if hasattr(bind.dialect, "inspector"): - cls = bind.dialect.inspector + cls = bind.dialect.inspector # type: ignore[attr-defined] self = cls.__new__(cls) init(self, bind) return self - def _init_legacy(self, bind): + def _init_legacy(self, bind: Union[Engine, Connection]) -> None: if hasattr(bind, "exec_driver_sql"): - self._init_connection(bind) + self._init_connection(bind) # type: ignore[arg-type] else: - self._init_engine(bind) + self._init_engine(bind) # type: ignore[arg-type] - def _init_engine(self, engine): + def _init_engine(self, engine: Engine) -> None: self.bind = self.engine = engine engine.connect().close() self._op_context_requires_connect = True self.dialect = self.engine.dialect self.info_cache = {} - def _init_connection(self, connection): + def _init_connection(self, connection: Connection) -> None: self.bind = connection self.engine = connection.engine self._op_context_requires_connect = False self.dialect = self.engine.dialect self.info_cache = {} + def clear_cache(self) -> None: + """reset the cache for this :class:`.Inspector`. + + Inspection methods that have data cached will emit SQL queries + when next called to get new data. + + .. versionadded:: 2.0 + + """ + self.info_cache.clear() + @classmethod @util.deprecated( "1.4", @@ -152,7 +280,7 @@ class Inspector(inspection.Inspectable["Inspector"]): "in order to " "acquire an :class:`_reflection.Inspector`.", ) - def from_engine(cls, bind): + def from_engine(cls, bind: Engine) -> Inspector: """Construct a new dialect-specific Inspector object from the given engine or connection. @@ -172,15 +300,15 @@ class Inspector(inspection.Inspectable["Inspector"]): return cls._construct(cls._init_legacy, bind) @inspection._inspects(Engine) - def _engine_insp(bind): + def _engine_insp(bind: Engine) -> Inspector: # type: ignore[misc] return Inspector._construct(Inspector._init_engine, bind) @inspection._inspects(Connection) - def _connection_insp(bind): + def _connection_insp(bind: Connection) -> Inspector: # type: ignore[misc] return Inspector._construct(Inspector._init_connection, bind) @contextlib.contextmanager - def _operation_context(self): + def _operation_context(self) -> Generator[Connection, None, None]: """Return a context that optimizes for multiple operations on a single transaction. @@ -189,10 +317,11 @@ class Inspector(inspection.Inspectable["Inspector"]): :class:`_engine.Connection`. 
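# --- illustrative sketch (not part of the patch) -------------------------
# has_table() and the name-list methods now cache their results per
# Inspector (info_cache), so DDL emitted after the first call is not seen
# until clear_cache() is called or a new Inspector is created.
# The SQLite file path is a placeholder.
from sqlalchemy import create_engine, inspect, text

engine = create_engine("sqlite:///example.db")
insp = inspect(engine)

print(insp.has_table("new_table"))   # False, and the result is now cached

with engine.begin() as conn:
    conn.execute(text("CREATE TABLE new_table (id INTEGER PRIMARY KEY)"))

print(insp.has_table("new_table"))   # still False: answered from the cache
insp.clear_cache()
print(insp.has_table("new_table"))   # True once the cache has been reset
# -------------------------------------------------------------------------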
""" + conn: Connection if self._op_context_requires_connect: - conn = self.bind.connect() + conn = self.bind.connect() # type: ignore[union-attr] else: - conn = self.bind + conn = self.bind # type: ignore[assignment] try: yield conn finally: @@ -200,7 +329,7 @@ class Inspector(inspection.Inspectable["Inspector"]): conn.close() @contextlib.contextmanager - def _inspection_context(self): + def _inspection_context(self) -> Generator[Inspector, None, None]: """Return an :class:`_reflection.Inspector` from this one that will run all operations on a single connection. @@ -213,7 +342,7 @@ class Inspector(inspection.Inspectable["Inspector"]): yield sub_insp @property - def default_schema_name(self): + def default_schema_name(self) -> Optional[str]: """Return the default schema name presented by the dialect for the current engine's database user. @@ -223,30 +352,38 @@ class Inspector(inspection.Inspectable["Inspector"]): """ return self.dialect.default_schema_name - def get_schema_names(self): - """Return all schema names.""" + def get_schema_names(self, **kw: Any) -> List[str]: + r"""Return all schema names. - if hasattr(self.dialect, "get_schema_names"): - with self._operation_context() as conn: - return self.dialect.get_schema_names( - conn, info_cache=self.info_cache - ) - return [] + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + """ - def get_table_names(self, schema=None): - """Return all table names in referred to within a particular schema. + with self._operation_context() as conn: + return self.dialect.get_schema_names( + conn, info_cache=self.info_cache, **kw + ) + + def get_table_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all table names within a particular schema. The names are expected to be real tables only, not views. Views are instead returned using the - :meth:`_reflection.Inspector.get_view_names` - method. - + :meth:`_reflection.Inspector.get_view_names` and/or + :meth:`_reflection.Inspector.get_materialized_view_names` + methods. :param schema: Schema name. If ``schema`` is left at ``None``, the database's default schema is used, else the named schema is searched. If the database does not support named schemas, behavior is undefined if ``schema`` is not passed as ``None``. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. .. seealso:: @@ -258,43 +395,105 @@ class Inspector(inspection.Inspectable["Inspector"]): with self._operation_context() as conn: return self.dialect.get_table_names( - conn, schema, info_cache=self.info_cache + conn, schema, info_cache=self.info_cache, **kw ) - def has_table(self, table_name, schema=None): - """Return True if the backend has a table or view of the given name. + def has_table( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + r"""Return True if the backend has a table or view of the given name. :param table_name: name of the table to check :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. .. versionadded:: 1.4 - the :meth:`.Inspector.has_table` method replaces the :meth:`_engine.Engine.has_table` method. - .. 
versionchanged:: 2.0:: The method checks also for views. + .. versionchanged:: 2.0:: The method checks also for any type of + views (plain or materialized). In previous version this behaviour was dialect specific. New dialect suite tests were added to ensure all dialect conform with this behaviour. """ - # TODO: info_cache? with self._operation_context() as conn: - return self.dialect.has_table(conn, table_name, schema) + return self.dialect.has_table( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) - def has_sequence(self, sequence_name, schema=None): - """Return True if the backend has a table of the given name. + def has_sequence( + self, sequence_name: str, schema: Optional[str] = None, **kw: Any + ) -> bool: + r"""Return True if the backend has a sequence with the given name. - :param sequence_name: name of the table to check + :param sequence_name: name of the sequence to check :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. .. versionadded:: 1.4 """ - # TODO: info_cache? with self._operation_context() as conn: - return self.dialect.has_sequence(conn, sequence_name, schema) + return self.dialect.has_sequence( + conn, sequence_name, schema, info_cache=self.info_cache, **kw + ) + + def has_index( + self, + table_name: str, + index_name: str, + schema: Optional[str] = None, + **kw: Any, + ) -> bool: + r"""Check the existence of a particular index name in the database. + + :param table_name: the name of the table the index belongs to + :param index_name: the name of the index to check + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_index( + conn, + table_name, + index_name, + schema, + info_cache=self.info_cache, + **kw, + ) + + def has_schema(self, schema_name: str, **kw: Any) -> bool: + r"""Return True if the backend has a schema with the given name. + + :param schema_name: name of the schema to check + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + """ + with self._operation_context() as conn: + return self.dialect.has_schema( + conn, schema_name, info_cache=self.info_cache, **kw + ) - def get_sorted_table_and_fkc_names(self, schema=None): - """Return dependency-sorted table and foreign key constraint names in + def get_sorted_table_and_fkc_names( + self, + schema: Optional[str] = None, + **kw: Any, + ) -> List[Tuple[Optional[str], List[Tuple[str, Optional[str]]]]]: + r"""Return dependency-sorted table and foreign key constraint names in referred to within a particular schema. This will yield 2-tuples of @@ -309,6 +508,11 @@ class Inspector(inspection.Inspectable["Inspector"]): .. versionadded:: 1.0.- + :param schema: schema name to query, if not the default schema. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + .. 
seealso:: :meth:`_reflection.Inspector.get_table_names` @@ -317,24 +521,74 @@ class Inspector(inspection.Inspectable["Inspector"]): with an already-given :class:`_schema.MetaData`. """ - with self._operation_context() as conn: - tnames = self.dialect.get_table_names( - conn, schema, info_cache=self.info_cache + + return [ + ( + table_key[1] if table_key else None, + [(tname, fks) for (_, tname), fks in fk_collection], ) + for ( + table_key, + fk_collection, + ) in self.sort_tables_on_foreign_key_dependency( + consider_schemas=(schema,) + ) + ] - tuples = set() - remaining_fkcs = set() + def sort_tables_on_foreign_key_dependency( + self, + consider_schemas: Collection[Optional[str]] = (None,), + **kw: Any, + ) -> List[ + Tuple[ + Optional[Tuple[Optional[str], str]], + List[Tuple[Tuple[Optional[str], str], Optional[str]]], + ] + ]: + r"""Return dependency-sorted table and foreign key constraint names + referred to within multiple schemas. + + This method may be compared to + :meth:`.Inspector.get_sorted_table_and_fkc_names`, which + works on one schema at a time; here, the method is a generalization + that will consider multiple schemas at once including that it will + resolve for cross-schema foreign keys. + + .. versionadded:: 2.0 - fknames_for_table = {} - for tname in tnames: - fkeys = self.get_foreign_keys(tname, schema) - fknames_for_table[tname] = set([fk["name"] for fk in fkeys]) - for fkey in fkeys: - if tname != fkey["referred_table"]: - tuples.add((fkey["referred_table"], tname)) + """ + SchemaTab = Tuple[Optional[str], str] + + tuples: Set[Tuple[SchemaTab, SchemaTab]] = set() + remaining_fkcs: Set[Tuple[SchemaTab, Optional[str]]] = set() + fknames_for_table: Dict[SchemaTab, Set[Optional[str]]] = {} + tnames: List[SchemaTab] = [] + + for schname in consider_schemas: + schema_fkeys = self.get_multi_foreign_keys(schname, **kw) + tnames.extend(schema_fkeys) + for (_, tname), fkeys in schema_fkeys.items(): + fknames_for_table[(schname, tname)] = set( + [fk["name"] for fk in fkeys] + ) + for fkey in fkeys: + if ( + tname != fkey["referred_table"] + or schname != fkey["referred_schema"] + ): + tuples.add( + ( + ( + fkey["referred_schema"], + fkey["referred_table"], + ), + (schname, tname), + ) + ) try: candidate_sort = list(topological.sort(tuples, tnames)) except exc.CircularDependencyError as err: + edge: Tuple[SchemaTab, SchemaTab] for edge in err.edges: tuples.remove(edge) remaining_fkcs.update( @@ -342,16 +596,32 @@ class Inspector(inspection.Inspectable["Inspector"]): ) candidate_sort = list(topological.sort(tuples, tnames)) - return [ - (tname, fknames_for_table[tname].difference(remaining_fkcs)) - for tname in candidate_sort - ] + [(None, list(remaining_fkcs))] + ret: List[ + Tuple[Optional[SchemaTab], List[Tuple[SchemaTab, Optional[str]]]] + ] + ret = [ + ( + (schname, tname), + [ + ((schname, tname), fk) + for fk in fknames_for_table[(schname, tname)].difference( + name for _, name in remaining_fkcs + ) + ], + ) + for (schname, tname) in candidate_sort + ] + return ret + [(None, list(remaining_fkcs))] - def get_temp_table_names(self): - """Return a list of temporary table names for the current bind. + def get_temp_table_names(self, **kw: Any) -> List[str]: + r"""Return a list of temporary table names for the current bind. This method is unsupported by most dialects; currently - only SQLite implements it. + only Oracle, PostgreSQL and SQLite implements it. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. 
See the documentation of the dialect + in use for more information. .. versionadded:: 1.0.0 @@ -359,28 +629,35 @@ class Inspector(inspection.Inspectable["Inspector"]): with self._operation_context() as conn: return self.dialect.get_temp_table_names( - conn, info_cache=self.info_cache + conn, info_cache=self.info_cache, **kw ) - def get_temp_view_names(self): - """Return a list of temporary view names for the current bind. + def get_temp_view_names(self, **kw: Any) -> List[str]: + r"""Return a list of temporary view names for the current bind. This method is unsupported by most dialects; currently - only SQLite implements it. + only PostgreSQL and SQLite implements it. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. .. versionadded:: 1.0.0 """ with self._operation_context() as conn: return self.dialect.get_temp_view_names( - conn, info_cache=self.info_cache + conn, info_cache=self.info_cache, **kw ) - def get_table_options(self, table_name, schema=None, **kw): - """Return a dictionary of options specified when the table of the + def get_table_options( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> Dict[str, Any]: + r"""Return a dictionary of options specified when the table of the given name was created. - This currently includes some options that apply to MySQL tables. + This currently includes some options that apply to MySQL and Oracle + tables. :param table_name: string name of the table. For special quoting, use :class:`.quoted_name`. @@ -389,60 +666,172 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dict with the table options. The returned keys depend on the + dialect in use. Each one is prefixed with the dialect name. + """ - if hasattr(self.dialect, "get_table_options"): - with self._operation_context() as conn: - return self.dialect.get_table_options( - conn, table_name, schema, info_cache=self.info_cache, **kw - ) - return {} + with self._operation_context() as conn: + return self.dialect.get_table_options( + conn, table_name, schema, info_cache=self.info_cache, **kw + ) + + def get_multi_table_options( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, Dict[str, Any]]: + r"""Return a dictionary of options specified when the tables in the + given schema were created. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + This currently includes some options that apply to MySQL and Oracle + tables. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if options of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. 
+ + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries with the table options. + The returned keys in each dict depend on the + dialect in use. Each one is prefixed with the dialect name. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + with self._operation_context() as conn: + res = self.dialect.get_multi_table_options( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + return dict(res) - def get_view_names(self, schema=None): - """Return all view names in `schema`. + def get_view_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all non-materialized view names in `schema`. :param schema: Optional, retrieve names from a non-default schema. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + + .. versionchanged:: 2.0 For those dialects that previously included + the names of materialized views in this list (currently PostgreSQL), + this method no longer returns the names of materialized views. + the :meth:`.Inspector.get_materialized_view_names` method should + be used instead. + + .. seealso:: + + :meth:`.Inspector.get_materialized_view_names` """ with self._operation_context() as conn: return self.dialect.get_view_names( - conn, schema, info_cache=self.info_cache + conn, schema, info_cache=self.info_cache, **kw + ) + + def get_materialized_view_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all materialized view names in `schema`. + + :param schema: Optional, retrieve names from a non-default schema. + For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + .. versionadded:: 2.0 + + .. seealso:: + + :meth:`.Inspector.get_view_names` + + """ + + with self._operation_context() as conn: + return self.dialect.get_materialized_view_names( + conn, schema, info_cache=self.info_cache, **kw ) - def get_sequence_names(self, schema=None): - """Return all sequence names in `schema`. + def get_sequence_names( + self, schema: Optional[str] = None, **kw: Any + ) -> List[str]: + r"""Return all sequence names in `schema`. :param schema: Optional, retrieve names from a non-default schema. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. """ with self._operation_context() as conn: return self.dialect.get_sequence_names( - conn, schema, info_cache=self.info_cache + conn, schema, info_cache=self.info_cache, **kw ) - def get_view_definition(self, view_name, schema=None): - """Return definition for `view_name`. + def get_view_definition( + self, view_name: str, schema: Optional[str] = None, **kw: Any + ) -> str: + r"""Return definition for the plain or materialized view called + ``view_name``. + :param view_name: Name of the view. :param schema: Optional, retrieve names from a non-default schema. For special quoting, use :class:`.quoted_name`. 
+ :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. """ with self._operation_context() as conn: return self.dialect.get_view_definition( - conn, view_name, schema, info_cache=self.info_cache + conn, view_name, schema, info_cache=self.info_cache, **kw ) def get_columns( - self, table_name: str, schema: Optional[str] = None, **kw + self, table_name: str, schema: Optional[str] = None, **kw: Any ) -> List[ReflectedColumn]: - """Return information about columns in `table_name`. + r"""Return information about columns in ``table_name``. - Given a string `table_name` and an optional string `schema`, return - column information as a list of dicts with these keys: + Given a string ``table_name`` and an optional string ``schema``, + return column information as a list of dicts with these keys: * ``name`` - the column's name @@ -487,6 +876,10 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + :return: list of dictionaries, each representing the definition of a database column. @@ -496,17 +889,83 @@ class Inspector(inspection.Inspectable["Inspector"]): col_defs = self.dialect.get_columns( conn, table_name, schema, info_cache=self.info_cache, **kw ) - for col_def in col_defs: - # make this easy and only return instances for coltype - coltype = col_def["type"] - if not isinstance(coltype, TypeEngine): - col_def["type"] = coltype() + if col_defs: + self._instantiate_types([col_defs]) return col_defs - def get_pk_constraint(self, table_name, schema=None, **kw): - """Return information about primary key constraint on `table_name`. + def _instantiate_types( + self, data: Iterable[List[ReflectedColumn]] + ) -> None: + # make this easy and only return instances for coltype + for col_defs in data: + for col_def in col_defs: + coltype = col_def["type"] + if not isinstance(coltype, TypeEngine): + col_def["type"] = coltype() + + def get_multi_columns( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedColumn]]: + r"""Return information about columns in all objects in the given schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. + + The column information is as described in + :meth:`Inspector.get_columns`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if columns of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a database column. 
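# Minimal runnable sketch (not part of the patch) of the batched column
# reflection described above; the in-memory SQLite database and table are
# purely illustrative.  Keys of the returned dict are (schema, table_name)
# two-tuples, with schema None for the default schema.
from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine, inspect
from sqlalchemy.engine import ObjectKind, ObjectScope

engine = create_engine("sqlite://")
metadata = MetaData()
Table("users", metadata, Column("id", Integer, primary_key=True), Column("name", String(30)))
metadata.create_all(engine)

insp = inspect(engine)
columns = insp.get_multi_columns(kind=ObjectKind.TABLE, scope=ObjectScope.DEFAULT)
for (schema, table_name), cols in columns.items():
    print(schema, table_name, [c["name"] for c in cols])  # None users ['id', 'name']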
+ The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + table_col_defs = dict( + self.dialect.get_multi_columns( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + self._instantiate_types(table_col_defs.values()) + return table_col_defs + + def get_pk_constraint( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> ReflectedPrimaryKeyConstraint: + r"""Return information about primary key constraint in ``table_name``. - Given a string `table_name`, and an optional string `schema`, return + Given a string ``table_name``, and an optional string `schema`, return primary key information as a dictionary with these keys: * ``constrained_columns`` - @@ -522,16 +981,80 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary representing the definition of + a primary key constraint. + """ with self._operation_context() as conn: return self.dialect.get_pk_constraint( conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_foreign_keys(self, table_name, schema=None, **kw): - """Return information about foreign_keys in `table_name`. + def get_multi_pk_constraint( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, ReflectedPrimaryKeyConstraint]: + r"""Return information about primary key constraints in + all tables in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + The primary key information is as described in + :meth:`Inspector.get_pk_constraint`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if primary keys of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries, each representing the + definition of a primary key constraint. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_pk_constraint( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_foreign_keys( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedForeignKeyConstraint]: + r"""Return information about foreign_keys in ``table_name``. 
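# Sketch (not part of the patch) contrasting per-table and batched primary key
# reflection; the in-memory SQLite database and table are illustrative
# assumptions only.
from sqlalchemy import create_engine, inspect, text

engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR(30))"))

insp = inspect(engine)
print(insp.get_pk_constraint("users"))   # one table at a time
print(insp.get_multi_pk_constraint())    # {(None, 'users'): {...}} for the whole schema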
- Given a string `table_name`, and an optional string `schema`, return + Given a string ``table_name``, and an optional string `schema`, return foreign key information as a list of dicts with these keys: * ``constrained_columns`` - @@ -557,6 +1080,13 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + a foreign key definition. + """ with self._operation_context() as conn: @@ -564,10 +1094,68 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_indexes(self, table_name, schema=None, **kw): - """Return information about indexes in `table_name`. + def get_multi_foreign_keys( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedForeignKeyConstraint]]: + r"""Return information about foreign_keys in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + The foreign key informations as described in + :meth:`Inspector.get_foreign_keys`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if foreign keys of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing + a foreign key definition. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_foreign_keys( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) - Given a string `table_name` and an optional string `schema`, return + def get_indexes( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedIndex]: + r"""Return information about indexes in ``table_name``. + + Given a string ``table_name`` and an optional string `schema`, return index information as a list of dicts with these keys: * ``name`` - @@ -598,6 +1186,13 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of an index. 
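# Sketch (not part of the patch) of index reflection with the keys documented
# above (name, column_names, unique); the SQLite database and index are
# illustrative assumptions.
from sqlalchemy import create_engine, inspect, text

engine = create_engine("sqlite://")
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE users (id INTEGER PRIMARY KEY, name VARCHAR(30))"))
    conn.execute(text("CREATE UNIQUE INDEX users_name_idx ON users (name)"))

insp = inspect(engine)
for idx in insp.get_indexes("users"):
    print(idx["name"], idx["column_names"], idx["unique"])  # index name, columns, unique flag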
+ """ with self._operation_context() as conn: @@ -605,10 +1200,71 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_unique_constraints(self, table_name, schema=None, **kw): - """Return information about unique constraints in `table_name`. + def get_multi_indexes( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedIndex]]: + r"""Return information about indexes in in all objects + in the given schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. + + The foreign key information is as described in + :meth:`Inspector.get_foreign_keys`. + + The indexes information as described in + :meth:`Inspector.get_indexes`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if indexes of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of an index. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_indexes( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_unique_constraints( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedUniqueConstraint]: + r"""Return information about unique constraints in ``table_name``. - Given a string `table_name` and an optional string `schema`, return + Given a string ``table_name`` and an optional string `schema`, return unique constraint information as a list of dicts with these keys: * ``name`` - @@ -624,6 +1280,13 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of an unique constraint. + """ with self._operation_context() as conn: @@ -631,8 +1294,66 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_table_comment(self, table_name, schema=None, **kw): - """Return information about the table comment for ``table_name``. 
+ def get_multi_unique_constraints( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedUniqueConstraint]]: + r"""Return information about unique constraints in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + The unique constraint information is as described in + :meth:`Inspector.get_unique_constraints`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if constraints of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of an unique constraint. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_unique_constraints( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_table_comment( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> ReflectedTableComment: + r"""Return information about the table comment for ``table_name``. Given a string ``table_name`` and an optional string ``schema``, return table comment information as a dictionary with these keys: @@ -643,8 +1364,20 @@ class Inspector(inspection.Inspectable["Inspector"]): Raises ``NotImplementedError`` for a dialect that does not support comments. - .. versionadded:: 1.2 + :param table_name: string name of the table. For special quoting, + use :class:`.quoted_name`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary, with the table comment. + .. versionadded:: 1.2 """ with self._operation_context() as conn: @@ -652,10 +1385,71 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) - def get_check_constraints(self, table_name, schema=None, **kw): - """Return information about check constraints in `table_name`. + def get_multi_table_comment( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, ReflectedTableComment]: + r"""Return information about the table comment in all objects + in the given schema. + + The objects can be filtered by passing the names to use to + ``filter_names``. 
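# Sketch (not part of the patch) of table comment reflection; the engine URL
# is a hypothetical PostgreSQL database, and the try/except mirrors the
# documented behavior of raising NotImplementedError on dialects without
# comment support.
from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")  # hypothetical URL
insp = inspect(engine)
try:
    print(insp.get_table_comment("comment_test")["text"])
except NotImplementedError:
    print("this dialect does not support table comments")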
+ + The comment information is as described in + :meth:`Inspector.get_table_comment`. + + Raises ``NotImplementedError`` for a dialect that does not support + comments. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if comments of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are dictionaries, representing the + table comments. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_table_comment( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + + def get_check_constraints( + self, table_name: str, schema: Optional[str] = None, **kw: Any + ) -> List[ReflectedCheckConstraint]: + r"""Return information about check constraints in ``table_name``. - Given a string `table_name` and an optional string `schema`, return + Given a string ``table_name`` and an optional string `schema`, return check constraint information as a list of dicts with these keys: * ``name`` - @@ -677,6 +1471,13 @@ class Inspector(inspection.Inspectable["Inspector"]): of the database connection. For special quoting, use :class:`.quoted_name`. + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. See the documentation of the dialect + in use for more information. + + :return: a list of dictionaries, each representing the + definition of a check constraints. + .. versionadded:: 1.1.0 """ @@ -686,14 +1487,71 @@ class Inspector(inspection.Inspectable["Inspector"]): conn, table_name, schema, info_cache=self.info_cache, **kw ) + def get_multi_check_constraints( + self, + schema: Optional[str] = None, + filter_names: Optional[Sequence[str]] = None, + kind: ObjectKind = ObjectKind.TABLE, + scope: ObjectScope = ObjectScope.DEFAULT, + **kw: Any, + ) -> Dict[TableKey, List[ReflectedCheckConstraint]]: + r"""Return information about check constraints in all tables + in the given schema. + + The tables can be filtered by passing the names to use to + ``filter_names``. + + The check constraint information is as described in + :meth:`Inspector.get_check_constraints`. + + :param schema: string schema name; if omitted, uses the default schema + of the database connection. For special quoting, + use :class:`.quoted_name`. + + :param filter_names: optionally return information only for the + objects listed here. + + :param kind: a :class:`.ObjectKind` that specifies the type of objects + to reflect. Defaults to ``ObjectKind.TABLE``. + + :param scope: a :class:`.ObjectScope` that specifies if constraints of + default, temporary or any tables should be reflected. + Defaults to ``ObjectScope.DEFAULT``. + + :param \**kw: Additional keyword argument to pass to the dialect + specific implementation. 
See the documentation of the dialect + in use for more information. + + :return: a dictionary where the keys are two-tuple schema,table-name + and the values are list of dictionaries, each representing the + definition of a check constraints. + The schema is ``None`` if no schema is provided. + + .. versionadded:: 2.0 + """ + + with self._operation_context() as conn: + return dict( + self.dialect.get_multi_check_constraints( + conn, + schema=schema, + filter_names=filter_names, + kind=kind, + scope=scope, + info_cache=self.info_cache, + **kw, + ) + ) + def reflect_table( self, - table, - include_columns, - exclude_columns=(), - resolve_fks=True, - _extend_on=None, - ): + table: sa_schema.Table, + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str] = (), + resolve_fks: bool = True, + _extend_on: Optional[Set[sa_schema.Table]] = None, + _reflect_info: Optional[_ReflectionInfo] = None, + ) -> None: """Given a :class:`_schema.Table` object, load its internal constructs based on introspection. @@ -741,21 +1599,34 @@ class Inspector(inspection.Inspectable["Inspector"]): if k in table.dialect_kwargs ) + table_key = (schema, table_name) + if _reflect_info is None or table_key not in _reflect_info.columns: + _reflect_info = self._get_reflection_info( + schema, + filter_names=[table_name], + kind=ObjectKind.ANY, + scope=ObjectScope.ANY, + _reflect_info=_reflect_info, + **table.dialect_kwargs, + ) + if table_key in _reflect_info.unreflectable: + raise _reflect_info.unreflectable[table_key] + + if table_key not in _reflect_info.columns: + raise exc.NoSuchTableError(table_name) + # reflect table options, like mysql_engine - tbl_opts = self.get_table_options( - table_name, schema, **table.dialect_kwargs - ) - if tbl_opts: - # add additional kwargs to the Table if the dialect - # returned them - table._validate_dialect_kwargs(tbl_opts) + if _reflect_info.table_options: + tbl_opts = _reflect_info.table_options.get(table_key) + if tbl_opts: + # add additional kwargs to the Table if the dialect + # returned them + table._validate_dialect_kwargs(tbl_opts) found_table = False - cols_by_orig_name = {} + cols_by_orig_name: Dict[str, sa_schema.Column[Any]] = {} - for col_d in self.get_columns( - table_name, schema, **table.dialect_kwargs - ): + for col_d in _reflect_info.columns[table_key]: found_table = True self._reflect_column( @@ -771,12 +1642,12 @@ class Inspector(inspection.Inspectable["Inspector"]): raise exc.NoSuchTableError(table_name) self._reflect_pk( - table_name, schema, table, cols_by_orig_name, exclude_columns + _reflect_info, table_key, table, cols_by_orig_name, exclude_columns ) self._reflect_fk( - table_name, - schema, + _reflect_info, + table_key, table, cols_by_orig_name, include_columns, @@ -787,8 +1658,8 @@ class Inspector(inspection.Inspectable["Inspector"]): ) self._reflect_indexes( - table_name, - schema, + _reflect_info, + table_key, table, cols_by_orig_name, include_columns, @@ -797,8 +1668,8 @@ class Inspector(inspection.Inspectable["Inspector"]): ) self._reflect_unique_constraints( - table_name, - schema, + _reflect_info, + table_key, table, cols_by_orig_name, include_columns, @@ -807,8 +1678,8 @@ class Inspector(inspection.Inspectable["Inspector"]): ) self._reflect_check_constraints( - table_name, - schema, + _reflect_info, + table_key, table, cols_by_orig_name, include_columns, @@ -817,17 +1688,27 @@ class Inspector(inspection.Inspectable["Inspector"]): ) self._reflect_table_comment( - table_name, schema, table, reflection_options + _reflect_info, + 
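# Sketch (not part of the patch) of the NoSuchTableError behavior used by
# reflect_table() above when a requested name is absent; the in-memory SQLite
# engine and table name are assumptions.
from sqlalchemy import MetaData, Table, create_engine
from sqlalchemy.exc import NoSuchTableError

engine = create_engine("sqlite://")
try:
    Table("missing_table", MetaData(), autoload_with=engine)
except NoSuchTableError:
    print("missing_table is not present in the database")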
table_key, + table, + reflection_options, ) def _reflect_column( - self, table, col_d, include_columns, exclude_columns, cols_by_orig_name - ): + self, + table: sa_schema.Table, + col_d: ReflectedColumn, + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + ) -> None: orig_name = col_d["name"] table.metadata.dispatch.column_reflect(self, table, col_d) - table.dispatch.column_reflect(self, table, col_d) + table.dispatch.column_reflect( # type: ignore[attr-defined] + self, table, col_d + ) # fetch name again as column_reflect is allowed to # change it @@ -840,7 +1721,7 @@ class Inspector(inspection.Inspectable["Inspector"]): coltype = col_d["type"] col_kw = dict( - (k, col_d[k]) + (k, col_d[k]) # type: ignore[literal-required] for k in [ "nullable", "autoincrement", @@ -856,15 +1737,20 @@ class Inspector(inspection.Inspectable["Inspector"]): col_kw.update(col_d["dialect_options"]) colargs = [] + default: Any if col_d.get("default") is not None: - default = col_d["default"] - if isinstance(default, sql.elements.TextClause): - default = sa_schema.DefaultClause(default, _reflected=True) - elif not isinstance(default, sa_schema.FetchedValue): + default_text = col_d["default"] + assert default_text is not None + if isinstance(default_text, TextClause): default = sa_schema.DefaultClause( - sql.text(col_d["default"]), _reflected=True + default_text, _reflected=True ) - + elif not isinstance(default_text, sa_schema.FetchedValue): + default = sa_schema.DefaultClause( + sql.text(default_text), _reflected=True + ) + else: + default = default_text colargs.append(default) if "computed" in col_d: @@ -872,11 +1758,8 @@ class Inspector(inspection.Inspectable["Inspector"]): colargs.append(computed) if "identity" in col_d: - computed = sa_schema.Identity(**col_d["identity"]) - colargs.append(computed) - - if "sequence" in col_d: - self._reflect_col_sequence(col_d, colargs) + identity = sa_schema.Identity(**col_d["identity"]) + colargs.append(identity) cols_by_orig_name[orig_name] = col = sa_schema.Column( name, coltype, *colargs, **col_kw @@ -886,23 +1769,15 @@ class Inspector(inspection.Inspectable["Inspector"]): col.primary_key = True table.append_column(col, replace_existing=True) - def _reflect_col_sequence(self, col_d, colargs): - if "sequence" in col_d: - # TODO: mssql is using this. 
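# Sketch (not part of the patch) of the column_reflect event dispatched in
# _reflect_column() above; the handler name and key-mangling rule are
# illustrative assumptions.
from sqlalchemy import Table, event

@event.listens_for(Table, "column_reflect")
def _receive_column_reflect(inspector, table, column_info):
    # column_info is the dict produced by Inspector.get_columns(); handlers
    # may modify it in place before the Column object is constructed
    column_info["key"] = "attr_%s" % column_info["name"].lower()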
- seq = col_d["sequence"] - sequence = sa_schema.Sequence(seq["name"], 1, 1) - if "start" in seq: - sequence.start = seq["start"] - if "increment" in seq: - sequence.increment = seq["increment"] - colargs.append(sequence) - def _reflect_pk( - self, table_name, schema, table, cols_by_orig_name, exclude_columns - ): - pk_cons = self.get_pk_constraint( - table_name, schema, **table.dialect_kwargs - ) + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + exclude_columns: Collection[str], + ) -> None: + pk_cons = _reflect_info.pk_constraint.get(table_key) if pk_cons: pk_cols = [ cols_by_orig_name[pk] @@ -919,19 +1794,17 @@ class Inspector(inspection.Inspectable["Inspector"]): def _reflect_fk( self, - table_name, - schema, - table, - cols_by_orig_name, - include_columns, - exclude_columns, - resolve_fks, - _extend_on, - reflection_options, - ): - fkeys = self.get_foreign_keys( - table_name, schema, **table.dialect_kwargs - ) + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + resolve_fks: bool, + _extend_on: Optional[Set[sa_schema.Table]], + reflection_options: Dict[str, Any], + ) -> None: + fkeys = _reflect_info.foreign_keys.get(table_key, []) for fkey_d in fkeys: conname = fkey_d["name"] # look for columns by orig name in cols_by_orig_name, @@ -963,6 +1836,7 @@ class Inspector(inspection.Inspectable["Inspector"]): schema=referred_schema, autoload_with=self.bind, _extend_on=_extend_on, + _reflect_info=_reflect_info, **reflection_options, ) for column in referred_columns: @@ -977,6 +1851,7 @@ class Inspector(inspection.Inspectable["Inspector"]): autoload_with=self.bind, schema=sa_schema.BLANK_SCHEMA, _extend_on=_extend_on, + _reflect_info=_reflect_info, **reflection_options, ) for column in referred_columns: @@ -1005,16 +1880,16 @@ class Inspector(inspection.Inspectable["Inspector"]): def _reflect_indexes( self, - table_name, - schema, - table, - cols_by_orig_name, - include_columns, - exclude_columns, - reflection_options, - ): + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: # Indexes - indexes = self.get_indexes(table_name, schema) + indexes = _reflect_info.indexes.get(table_key, []) for index_d in indexes: name = index_d["name"] columns = index_d["column_names"] @@ -1034,6 +1909,7 @@ class Inspector(inspection.Inspectable["Inspector"]): continue # look for columns by orig name in cols_by_orig_name, # but support columns that are in-Python only as fallback + idx_col: Any idx_cols = [] for c in columns: try: @@ -1045,7 +1921,7 @@ class Inspector(inspection.Inspectable["Inspector"]): except KeyError: util.warn( "%s key '%s' was not located in " - "columns for table '%s'" % (flavor, c, table_name) + "columns for table '%s'" % (flavor, c, table.name) ) continue c_sorting = column_sorting.get(c, ()) @@ -1063,22 +1939,16 @@ class Inspector(inspection.Inspectable["Inspector"]): def _reflect_unique_constraints( self, - table_name, - schema, - table, - cols_by_orig_name, - include_columns, - exclude_columns, - reflection_options, - ): - + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + 
cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + constraints = _reflect_info.unique_constraints.get(table_key, []) # Unique Constraints - try: - constraints = self.get_unique_constraints(table_name, schema) - except NotImplementedError: - # optional dialect feature - return - for const_d in constraints: conname = const_d["name"] columns = const_d["column_names"] @@ -1104,7 +1974,7 @@ class Inspector(inspection.Inspectable["Inspector"]): except KeyError: util.warn( "unique constraint key '%s' was not located in " - "columns for table '%s'" % (c, table_name) + "columns for table '%s'" % (c, table.name) ) else: constrained_cols.append(constrained_col) @@ -1114,29 +1984,166 @@ class Inspector(inspection.Inspectable["Inspector"]): def _reflect_check_constraints( self, - table_name, - schema, - table, - cols_by_orig_name, - include_columns, - exclude_columns, - reflection_options, - ): - try: - constraints = self.get_check_constraints(table_name, schema) - except NotImplementedError: - # optional dialect feature - return - + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + cols_by_orig_name: Dict[str, sa_schema.Column[Any]], + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str], + reflection_options: Dict[str, Any], + ) -> None: + constraints = _reflect_info.check_constraints.get(table_key, []) for const_d in constraints: table.append_constraint(sa_schema.CheckConstraint(**const_d)) def _reflect_table_comment( - self, table_name, schema, table, reflection_options - ): - try: - comment_dict = self.get_table_comment(table_name, schema) - except NotImplementedError: - return + self, + _reflect_info: _ReflectionInfo, + table_key: TableKey, + table: sa_schema.Table, + reflection_options: Dict[str, Any], + ) -> None: + comment_dict = _reflect_info.table_comment.get(table_key) + if comment_dict: + table.comment = comment_dict["text"] + + def _get_reflection_info( + self, + schema: Optional[str] = None, + filter_names: Optional[Collection[str]] = None, + available: Optional[Collection[str]] = None, + _reflect_info: Optional[_ReflectionInfo] = None, + **kw: Any, + ) -> _ReflectionInfo: + kw["schema"] = schema + + if filter_names and available and len(filter_names) > 100: + fraction = len(filter_names) / len(available) + else: + fraction = None + + unreflectable: Dict[TableKey, exc.UnreflectableTableError] + kw["unreflectable"] = unreflectable = {} + + has_result: bool = True + + def run( + meth: Any, + *, + optional: bool = False, + check_filter_names_from_meth: bool = False, + ) -> Any: + nonlocal has_result + # simple heuristic to improve reflection performance if a + # dialect implements multi_reflection: + # if more than 50% of the tables in the db are in filter_names + # load all the tables, since it's most likely faster to avoid + # a filter on that many tables. + if ( + fraction is None + or fraction <= 0.5 + or not self.dialect._overrides_default(meth.__name__) + ): + _fn = filter_names + else: + _fn = None + try: + if has_result: + res = meth(filter_names=_fn, **kw) + if check_filter_names_from_meth and not res: + # method returned no result data. 
+ # skip any future call methods + has_result = False + else: + res = {} + except NotImplementedError: + if not optional: + raise + res = {} + return res + + info = _ReflectionInfo( + columns=run( + self.get_multi_columns, check_filter_names_from_meth=True + ), + pk_constraint=run(self.get_multi_pk_constraint), + foreign_keys=run(self.get_multi_foreign_keys), + indexes=run(self.get_multi_indexes), + unique_constraints=run( + self.get_multi_unique_constraints, optional=True + ), + table_comment=run(self.get_multi_table_comment, optional=True), + check_constraints=run( + self.get_multi_check_constraints, optional=True + ), + table_options=run(self.get_multi_table_options, optional=True), + unreflectable=unreflectable, + ) + if _reflect_info: + _reflect_info.update(info) + return _reflect_info else: - table.comment = comment_dict.get("text", None) + return info + + +@final +class ReflectionDefaults: + """provides blank default values for reflection methods.""" + + @classmethod + def columns(cls) -> List[ReflectedColumn]: + return [] + + @classmethod + def pk_constraint(cls) -> ReflectedPrimaryKeyConstraint: + return { # type: ignore # pep-655 not supported + "name": None, + "constrained_columns": [], + } + + @classmethod + def foreign_keys(cls) -> List[ReflectedForeignKeyConstraint]: + return [] + + @classmethod + def indexes(cls) -> List[ReflectedIndex]: + return [] + + @classmethod + def unique_constraints(cls) -> List[ReflectedUniqueConstraint]: + return [] + + @classmethod + def check_constraints(cls) -> List[ReflectedCheckConstraint]: + return [] + + @classmethod + def table_options(cls) -> Dict[str, Any]: + return {} + + @classmethod + def table_comment(cls) -> ReflectedTableComment: + return {"text": None} + + +@dataclass +class _ReflectionInfo: + columns: Dict[TableKey, List[ReflectedColumn]] + pk_constraint: Dict[TableKey, Optional[ReflectedPrimaryKeyConstraint]] + foreign_keys: Dict[TableKey, List[ReflectedForeignKeyConstraint]] + indexes: Dict[TableKey, List[ReflectedIndex]] + # optionals + unique_constraints: Dict[TableKey, List[ReflectedUniqueConstraint]] + table_comment: Dict[TableKey, Optional[ReflectedTableComment]] + check_constraints: Dict[TableKey, List[ReflectedCheckConstraint]] + table_options: Dict[TableKey, Dict[str, Any]] + unreflectable: Dict[TableKey, exc.UnreflectableTableError] + + def update(self, other: _ReflectionInfo) -> None: + for k, v in self.__dict__.items(): + ov = getattr(other, k) + if ov is not None: + if v is None: + setattr(self, k, ov) + else: + v.update(ov) diff --git a/lib/sqlalchemy/sql/base.py b/lib/sqlalchemy/sql/base.py index 391f74772..70c01d8d3 100644 --- a/lib/sqlalchemy/sql/base.py +++ b/lib/sqlalchemy/sql/base.py @@ -536,7 +536,7 @@ class DialectKWArgs: util.portable_instancemethod(self._kw_reg_for_dialect_cls) ) - def _validate_dialect_kwargs(self, kwargs: Any) -> None: + def _validate_dialect_kwargs(self, kwargs: Dict[str, Any]) -> None: # validate remaining kwargs that they all specify DB prefixes if not kwargs: diff --git a/lib/sqlalchemy/sql/cache_key.py b/lib/sqlalchemy/sql/cache_key.py index c16fbdae1..5922c2db0 100644 --- a/lib/sqlalchemy/sql/cache_key.py +++ b/lib/sqlalchemy/sql/cache_key.py @@ -12,6 +12,7 @@ from itertools import zip_longest import typing from typing import Any from typing import Dict +from typing import Iterable from typing import Iterator from typing import List from typing import MutableMapping @@ -546,6 +547,43 @@ class CacheKey(NamedTuple): return target_element.params(translate) +def 
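# Sketch (not part of the patch) showing the blank values provided by
# ReflectionDefaults above, which a dialect may return for a table that has
# none of the given constructs; printing them is purely illustrative.
from sqlalchemy.engine.reflection import ReflectionDefaults

print(ReflectionDefaults.pk_constraint())  # {'name': None, 'constrained_columns': []}
print(ReflectionDefaults.table_comment())  # {'text': None}
print(ReflectionDefaults.indexes())        # []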
_ad_hoc_cache_key_from_args( + tokens: Tuple[Any, ...], + traverse_args: Iterable[Tuple[str, InternalTraversal]], + args: Iterable[Any], +) -> Tuple[Any, ...]: + """a quick cache key generator used by reflection.flexi_cache.""" + bindparams: List[BindParameter[Any]] = [] + + _anon_map = anon_map() + + tup = tokens + + for (attrname, sym), arg in zip(traverse_args, args): + key = sym.name + visit_key = key.replace("dp_", "visit_") + + if arg is None: + tup += (attrname, None) + continue + + meth = getattr(_cache_key_traversal_visitor, visit_key) + if meth is CACHE_IN_PLACE: + tup += (attrname, arg) + elif meth in ( + CALL_GEN_CACHE_KEY, + STATIC_CACHE_KEY, + ANON_NAME, + PROPAGATE_ATTRS, + ): + raise NotImplementedError( + f"Haven't implemented symbol {meth} for ad-hoc key from args" + ) + else: + tup += meth(attrname, arg, None, _anon_map, bindparams) + return tup + + class _CacheKeyTraversal(HasTraversalDispatch): # very common elements are inlined into the main _get_cache_key() method # to produce a dramatic savings in Python function call overhead diff --git a/lib/sqlalchemy/sql/schema.py b/lib/sqlalchemy/sql/schema.py index c37b60003..1c4b3b0ce 100644 --- a/lib/sqlalchemy/sql/schema.py +++ b/lib/sqlalchemy/sql/schema.py @@ -38,6 +38,7 @@ import typing from typing import Any from typing import Callable from typing import cast +from typing import Collection from typing import Dict from typing import Iterable from typing import Iterator @@ -99,6 +100,7 @@ if typing.TYPE_CHECKING: from ..engine.interfaces import _ExecuteOptionsParameter from ..engine.interfaces import ExecutionContext from ..engine.mock import MockConnection + from ..engine.reflection import _ReflectionInfo from ..sql.selectable import FromClause _T = TypeVar("_T", bound="Any") @@ -493,7 +495,7 @@ class Table( keep_existing: bool = False, extend_existing: bool = False, resolve_fks: bool = True, - include_columns: Optional[Iterable[str]] = None, + include_columns: Optional[Collection[str]] = None, implicit_returning: bool = True, comment: Optional[str] = None, info: Optional[Dict[Any, Any]] = None, @@ -829,6 +831,7 @@ class Table( self.fullname = self.name self.implicit_returning = implicit_returning + _reflect_info = kw.pop("_reflect_info", None) self.comment = comment @@ -852,6 +855,7 @@ class Table( autoload_with, include_columns, _extend_on=_extend_on, + _reflect_info=_reflect_info, resolve_fks=resolve_fks, ) @@ -869,10 +873,11 @@ class Table( self, metadata: MetaData, autoload_with: Union[Engine, Connection], - include_columns: Optional[Iterable[str]], - exclude_columns: Iterable[str] = (), + include_columns: Optional[Collection[str]], + exclude_columns: Collection[str] = (), resolve_fks: bool = True, _extend_on: Optional[Set[Table]] = None, + _reflect_info: _ReflectionInfo | None = None, ) -> None: insp = inspection.inspect(autoload_with) with insp._inspection_context() as conn_insp: @@ -882,6 +887,7 @@ class Table( exclude_columns, resolve_fks, _extend_on=_extend_on, + _reflect_info=_reflect_info, ) @property @@ -924,6 +930,7 @@ class Table( autoload_replace = kwargs.pop("autoload_replace", True) schema = kwargs.pop("schema", None) _extend_on = kwargs.pop("_extend_on", None) + _reflect_info = kwargs.pop("_reflect_info", None) # these arguments are only used with _init() kwargs.pop("extend_existing", False) kwargs.pop("keep_existing", False) @@ -972,6 +979,7 @@ class Table( exclude_columns, resolve_fks, _extend_on=_extend_on, + _reflect_info=_reflect_info, ) self._extra_kwargs(**kwargs) @@ -3165,7 +3173,7 @@ class 
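# Sketch (not part of the patch) of the corrected Identity "cache" parameter
# typing changed above: an integer count of pre-allocated values rather than a
# boolean.  The table and column names are illustrative.
from sqlalchemy import Column, Identity, Integer, MetaData, Table

metadata = MetaData()
events = Table(
    "events",
    metadata,
    Column("id", Integer, Identity(start=1, increment=1, cache=100), primary_key=True),
)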
IdentityOptions: nominvalue: Optional[bool] = None, nomaxvalue: Optional[bool] = None, cycle: Optional[bool] = None, - cache: Optional[bool] = None, + cache: Optional[int] = None, order: Optional[bool] = None, ) -> None: """Construct a :class:`.IdentityOptions` object. @@ -5130,6 +5138,7 @@ class MetaData(HasSchemaAttr): sorted(self.tables.values(), key=lambda t: t.key) # type: ignore ) + @util.preload_module("sqlalchemy.engine.reflection") def reflect( self, bind: Union[Engine, Connection], @@ -5159,7 +5168,7 @@ class MetaData(HasSchemaAttr): is used, if any. :param views: - If True, also reflect views. + If True, also reflect views (materialized and plain). :param only: Optional. Load only a sub-set of available named tables. May be @@ -5225,7 +5234,7 @@ class MetaData(HasSchemaAttr): """ with inspection.inspect(bind)._inspection_context() as insp: - reflect_opts = { + reflect_opts: Any = { "autoload_with": insp, "extend_existing": extend_existing, "autoload_replace": autoload_replace, @@ -5241,15 +5250,21 @@ class MetaData(HasSchemaAttr): if schema is not None: reflect_opts["schema"] = schema + kind = util.preloaded.engine_reflection.ObjectKind.TABLE available: util.OrderedSet[str] = util.OrderedSet( insp.get_table_names(schema) ) if views: + kind = util.preloaded.engine_reflection.ObjectKind.ANY available.update(insp.get_view_names(schema)) + try: + available.update(insp.get_materialized_view_names(schema)) + except NotImplementedError: + pass if schema is not None: available_w_schema: util.OrderedSet[str] = util.OrderedSet( - ["%s.%s" % (schema, name) for name in available] + [f"{schema}.{name}" for name in available] ) else: available_w_schema = available @@ -5282,6 +5297,17 @@ class MetaData(HasSchemaAttr): for name in only if extend_existing or name not in current ] + # pass the available tables so the inspector can + # choose to ignore the filter_names + _reflect_info = insp._get_reflection_info( + schema=schema, + filter_names=load, + available=available, + kind=kind, + scope=util.preloaded.engine_reflection.ObjectScope.ANY, + **dialect_kwargs, + ) + reflect_opts["_reflect_info"] = _reflect_info for name in load: try: @@ -5489,7 +5515,7 @@ class Identity(IdentityOptions, FetchedValue, SchemaItem): nominvalue: Optional[bool] = None, nomaxvalue: Optional[bool] = None, cycle: Optional[bool] = None, - cache: Optional[bool] = None, + cache: Optional[int] = None, order: Optional[bool] = None, ) -> None: """Construct a GENERATED { ALWAYS | BY DEFAULT } AS IDENTITY DDL diff --git a/lib/sqlalchemy/testing/assertions.py b/lib/sqlalchemy/testing/assertions.py index 9888d7c18..937706363 100644 --- a/lib/sqlalchemy/testing/assertions.py +++ b/lib/sqlalchemy/testing/assertions.py @@ -644,13 +644,21 @@ class AssertsCompiledSQL: class ComparesTables: - def assert_tables_equal(self, table, reflected_table, strict_types=False): + def assert_tables_equal( + self, + table, + reflected_table, + strict_types=False, + strict_constraints=True, + ): assert len(table.c) == len(reflected_table.c) for c, reflected_c in zip(table.c, reflected_table.c): eq_(c.name, reflected_c.name) assert reflected_c is reflected_table.c[c.name] - eq_(c.primary_key, reflected_c.primary_key) - eq_(c.nullable, reflected_c.nullable) + + if strict_constraints: + eq_(c.primary_key, reflected_c.primary_key) + eq_(c.nullable, reflected_c.nullable) if strict_types: msg = "Type '%s' doesn't correspond to type '%s'" @@ -664,18 +672,20 @@ class ComparesTables: if isinstance(c.type, sqltypes.String): eq_(c.type.length, 
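# Sketch (not part of the patch) of MetaData.reflect() with views=True as
# updated above, which now also pulls in materialized views where the dialect
# supports them; the connection URL is a hypothetical PostgreSQL database.
from sqlalchemy import MetaData, create_engine

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")  # hypothetical URL
metadata = MetaData()
metadata.reflect(engine, views=True)
print(sorted(metadata.tables))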
reflected_c.type.length) - eq_( - {f.column.name for f in c.foreign_keys}, - {f.column.name for f in reflected_c.foreign_keys}, - ) + if strict_constraints: + eq_( + {f.column.name for f in c.foreign_keys}, + {f.column.name for f in reflected_c.foreign_keys}, + ) if c.server_default: assert isinstance( reflected_c.server_default, schema.FetchedValue ) - assert len(table.primary_key) == len(reflected_table.primary_key) - for c in table.primary_key: - assert reflected_table.primary_key.columns[c.name] is not None + if strict_constraints: + assert len(table.primary_key) == len(reflected_table.primary_key) + for c in table.primary_key: + assert reflected_table.primary_key.columns[c.name] is not None def assert_types_base(self, c1, c2): assert c1.type._compare_type_affinity( diff --git a/lib/sqlalchemy/testing/plugin/pytestplugin.py b/lib/sqlalchemy/testing/plugin/pytestplugin.py index fa7d2ca19..cea07b305 100644 --- a/lib/sqlalchemy/testing/plugin/pytestplugin.py +++ b/lib/sqlalchemy/testing/plugin/pytestplugin.py @@ -741,13 +741,18 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): fn._sa_parametrize.append((argnames, pytest_params)) return fn else: + _fn_argnames = inspect.getfullargspec(fn).args[1:] if argnames is None: - _argnames = inspect.getfullargspec(fn).args[1:] + _argnames = _fn_argnames else: _argnames = re.split(r", *", argnames) if has_exclusions: - _argnames += ["_exclusions"] + existing_exl = sum( + 1 for n in _fn_argnames if n.startswith("_exclusions") + ) + current_exclusion_name = f"_exclusions_{existing_exl}" + _argnames += [current_exclusion_name] @_pytest_fn_decorator def check_exclusions(fn, *args, **kw): @@ -755,13 +760,10 @@ class PytestFixtureFunctions(plugin_base.FixtureFunctions): if _exclusions: exlu = exclusions.compound().add(*_exclusions) fn = exlu(fn) - return fn(*args[0:-1], **kw) - - def process_metadata(spec): - spec.args.append("_exclusions") + return fn(*args[:-1], **kw) fn = check_exclusions( - fn, add_positional_parameters=("_exclusions",) + fn, add_positional_parameters=(current_exclusion_name,) ) return pytest.mark.parametrize(_argnames, pytest_params)(fn) diff --git a/lib/sqlalchemy/testing/provision.py b/lib/sqlalchemy/testing/provision.py index d38437732..498d92a77 100644 --- a/lib/sqlalchemy/testing/provision.py +++ b/lib/sqlalchemy/testing/provision.py @@ -230,7 +230,39 @@ def drop_all_schema_objects(cfg, eng): drop_all_schema_objects_pre_tables(cfg, eng) + drop_views(cfg, eng) + + if config.requirements.materialized_views.enabled: + drop_materialized_views(cfg, eng) + inspector = inspect(eng) + + consider_schemas = (None,) + if config.requirements.schemas.enabled_for_config(cfg): + consider_schemas += (cfg.test_schema, cfg.test_schema_2) + util.drop_all_tables(eng, inspector, consider_schemas=consider_schemas) + + drop_all_schema_objects_post_tables(cfg, eng) + + if config.requirements.sequences.enabled_for_config(cfg): + with eng.begin() as conn: + for seq in inspector.get_sequence_names(): + conn.execute(ddl.DropSequence(schema.Sequence(seq))) + if config.requirements.schemas.enabled_for_config(cfg): + for schema_name in [cfg.test_schema, cfg.test_schema_2]: + for seq in inspector.get_sequence_names( + schema=schema_name + ): + conn.execute( + ddl.DropSequence( + schema.Sequence(seq, schema=schema_name) + ) + ) + + +def drop_views(cfg, eng): + inspector = inspect(eng) + try: view_names = inspector.get_view_names() except NotImplementedError: @@ -244,7 +276,7 @@ def drop_all_schema_objects(cfg, eng): if 
config.requirements.schemas.enabled_for_config(cfg): try: - view_names = inspector.get_view_names(schema="test_schema") + view_names = inspector.get_view_names(schema=cfg.test_schema) except NotImplementedError: pass else: @@ -255,32 +287,30 @@ def drop_all_schema_objects(cfg, eng): schema.Table( vname, schema.MetaData(), - schema="test_schema", + schema=cfg.test_schema, ) ) ) - util.drop_all_tables(eng, inspector) - if config.requirements.schemas.enabled_for_config(cfg): - util.drop_all_tables(eng, inspector, schema=cfg.test_schema) - util.drop_all_tables(eng, inspector, schema=cfg.test_schema_2) - drop_all_schema_objects_post_tables(cfg, eng) +def drop_materialized_views(cfg, eng): + inspector = inspect(eng) - if config.requirements.sequences.enabled_for_config(cfg): + mview_names = inspector.get_materialized_view_names() + + with eng.begin() as conn: + for vname in mview_names: + conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {vname}") + + if config.requirements.schemas.enabled_for_config(cfg): + mview_names = inspector.get_materialized_view_names( + schema=cfg.test_schema + ) with eng.begin() as conn: - for seq in inspector.get_sequence_names(): - conn.execute(ddl.DropSequence(schema.Sequence(seq))) - if config.requirements.schemas.enabled_for_config(cfg): - for schema_name in [cfg.test_schema, cfg.test_schema_2]: - for seq in inspector.get_sequence_names( - schema=schema_name - ): - conn.execute( - ddl.DropSequence( - schema.Sequence(seq, schema=schema_name) - ) - ) + for vname in mview_names: + conn.exec_driver_sql( + f"DROP MATERIALIZED VIEW {cfg.test_schema}.{vname}" + ) @register.init diff --git a/lib/sqlalchemy/testing/requirements.py b/lib/sqlalchemy/testing/requirements.py index 4f9c73cf6..038f6e9bd 100644 --- a/lib/sqlalchemy/testing/requirements.py +++ b/lib/sqlalchemy/testing/requirements.py @@ -65,6 +65,25 @@ class SuiteRequirements(Requirements): return exclusions.open() @property + def foreign_keys_reflect_as_index(self): + """Target database creates an index that's reflected for + foreign keys.""" + + return exclusions.closed() + + @property + def unique_index_reflect_as_unique_constraints(self): + """Target database reflects unique indexes as unique constrains.""" + + return exclusions.closed() + + @property + def unique_constraints_reflect_as_index(self): + """Target database reflects unique constraints as indexes.""" + + return exclusions.closed() + + @property def table_value_constructor(self): """Database / dialect supports a query like:: @@ -629,6 +648,12 @@ class SuiteRequirements(Requirements): return self.schemas @property + def schema_create_delete(self): + """target database supports schema create and dropped with + 'CREATE SCHEMA' and 'DROP SCHEMA'""" + return exclusions.closed() + + @property def primary_key_constraint_reflection(self): return exclusions.open() @@ -693,6 +718,12 @@ class SuiteRequirements(Requirements): return exclusions.open() @property + def reflect_indexes_with_ascdesc(self): + """target database supports reflecting INDEX with per-column + ASC/DESC.""" + return exclusions.open() + + @property def indexes_with_expressions(self): """target database supports CREATE INDEX against SQL expressions.""" return exclusions.closed() @@ -1567,3 +1598,18 @@ class SuiteRequirements(Requirements): def json_deserializer_binary(self): "indicates if the json_deserializer function is called with bytes" return exclusions.closed() + + @property + def reflect_table_options(self): + """Target database must support reflecting table_options.""" + return 
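# Sketch (not part of the patch) of the cleanup pattern used by
# drop_materialized_views() above; the connection URL is a hypothetical
# PostgreSQL database.
from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test")  # hypothetical URL
insp = inspect(engine)
with engine.begin() as conn:
    for name in insp.get_materialized_view_names():
        conn.exec_driver_sql(f"DROP MATERIALIZED VIEW {name}")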
exclusions.closed() + + @property + def materialized_views(self): + """Target database must support MATERIALIZED VIEWs.""" + return exclusions.closed() + + @property + def materialized_views_reflect_pk(self): + """Target database reflect MATERIALIZED VIEWs pks.""" + return exclusions.closed() diff --git a/lib/sqlalchemy/testing/schema.py b/lib/sqlalchemy/testing/schema.py index e4a92a732..46cbf4759 100644 --- a/lib/sqlalchemy/testing/schema.py +++ b/lib/sqlalchemy/testing/schema.py @@ -23,7 +23,7 @@ __all__ = ["Table", "Column"] table_options = {} -def Table(*args, **kw): +def Table(*args, **kw) -> schema.Table: """A schema.Table wrapper/hook for dialect-specific tweaks.""" test_opts = {k: kw.pop(k) for k in list(kw) if k.startswith("test_")} @@ -134,6 +134,19 @@ class eq_type_affinity: return self.target._type_affinity is not other._type_affinity +class eq_compile_type: + """similar to eq_type_affinity but uses compile""" + + def __init__(self, target): + self.target = target + + def __eq__(self, other): + return self.target == other.compile() + + def __ne__(self, other): + return self.target != other.compile() + + class eq_clause_element: """Helper to compare SQL structures based on compare()""" diff --git a/lib/sqlalchemy/testing/suite/test_reflection.py b/lib/sqlalchemy/testing/suite/test_reflection.py index b09b96227..7b8e2aa8b 100644 --- a/lib/sqlalchemy/testing/suite/test_reflection.py +++ b/lib/sqlalchemy/testing/suite/test_reflection.py @@ -7,6 +7,8 @@ import sqlalchemy as sa from .. import config from .. import engines from .. import eq_ +from .. import expect_raises +from .. import expect_raises_message from .. import expect_warnings from .. import fixtures from .. import is_ @@ -24,12 +26,19 @@ from ... import MetaData from ... import String from ... import testing from ... 
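# Sketch (not part of the patch) of how a hypothetical third-party dialect
# might opt into the new requirement flags added above from its own
# requirements.py; the class layout and chosen properties are assumptions.
from sqlalchemy.testing import exclusions
from sqlalchemy.testing.requirements import SuiteRequirements

class Requirements(SuiteRequirements):
    @property
    def materialized_views(self):
        # enable the materialized-view suite tests for this dialect
        return exclusions.open()

    @property
    def reflect_table_options(self):
        return exclusions.open()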
import types as sql_types +from ...engine import Inspector +from ...engine import ObjectKind +from ...engine import ObjectScope +from ...exc import NoSuchTableError +from ...exc import UnreflectableTableError from ...schema import DDL from ...schema import Index from ...sql.elements import quoted_name from ...sql.schema import BLANK_SCHEMA +from ...testing import ComparesTables from ...testing import is_false from ...testing import is_true +from ...testing import mock metadata, users = None, None @@ -61,6 +70,19 @@ class HasTableTest(fixtures.TablesTest): is_false(config.db.dialect.has_table(conn, "test_table_s")) is_false(config.db.dialect.has_table(conn, "nonexistent_table")) + def test_has_table_cache(self, metadata): + insp = inspect(config.db) + is_true(insp.has_table("test_table")) + nt = Table("new_table", metadata, Column("col", Integer)) + is_false(insp.has_table("new_table")) + nt.create(config.db) + try: + is_false(insp.has_table("new_table")) + insp.clear_cache() + is_true(insp.has_table("new_table")) + finally: + nt.drop(config.db) + @testing.requires.schemas def test_has_table_schema(self): with config.db.begin() as conn: @@ -117,6 +139,7 @@ class HasIndexTest(fixtures.TablesTest): metadata, Column("id", Integer, primary_key=True), Column("data", String(50)), + Column("data2", String(50)), ) Index("my_idx", tt.c.data) @@ -130,40 +153,56 @@ class HasIndexTest(fixtures.TablesTest): ) Index("my_idx_s", tt.c.data) - def test_has_index(self): - with config.db.begin() as conn: - assert config.db.dialect.has_index(conn, "test_table", "my_idx") - assert not config.db.dialect.has_index( - conn, "test_table", "my_idx_s" - ) - assert not config.db.dialect.has_index( - conn, "nonexistent_table", "my_idx" - ) - assert not config.db.dialect.has_index( - conn, "test_table", "nonexistent_idx" - ) + kind = testing.combinations("dialect", "inspector", argnames="kind") + + def _has_index(self, kind, conn): + if kind == "dialect": + return lambda *a, **k: config.db.dialect.has_index(conn, *a, **k) + else: + return inspect(conn).has_index + + @kind + def test_has_index(self, kind, connection, metadata): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx") + assert not meth("test_table", "my_idx_s") + assert not meth("nonexistent_table", "my_idx") + assert not meth("test_table", "nonexistent_idx") + + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + idx = Index("my_idx_2", self.tables.test_table.c.data2) + tbl = Table( + "test_table_2", + metadata, + Column("foo", Integer), + Index("my_idx_3", "foo"), + ) + idx.create(connection) + tbl.create(connection) + try: + if kind == "inspector": + assert not meth("test_table", "my_idx_2") + assert not meth("test_table_2", "my_idx_3") + meth.__self__.clear_cache() + assert meth("test_table", "my_idx_2") is True + assert meth("test_table_2", "my_idx_3") is True + finally: + tbl.drop(connection) + idx.drop(connection) @testing.requires.schemas - def test_has_index_schema(self): - with config.db.begin() as conn: - assert config.db.dialect.has_index( - conn, "test_table", "my_idx_s", schema=config.test_schema - ) - assert not config.db.dialect.has_index( - conn, "test_table", "my_idx", schema=config.test_schema - ) - assert not config.db.dialect.has_index( - conn, - "nonexistent_table", - "my_idx_s", - schema=config.test_schema, - ) - assert not config.db.dialect.has_index( - conn, - "test_table", - "nonexistent_idx_s", - schema=config.test_schema, - ) + @kind + def test_has_index_schema(self, 
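# Sketch (not part of the patch) of the Inspector caching behavior exercised
# by test_has_table_cache above; the in-memory SQLite engine and table name
# are assumptions.
from sqlalchemy import create_engine, inspect, text

engine = create_engine("sqlite://")
insp = inspect(engine)
print(insp.has_table("new_table"))   # False, and the answer is now cached
with engine.begin() as conn:
    conn.execute(text("CREATE TABLE new_table (col INTEGER)"))
print(insp.has_table("new_table"))   # still False: served from the cache
insp.clear_cache()
print(insp.has_table("new_table"))   # True after clearing the cache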
kind, connection): + meth = self._has_index(kind, connection) + assert meth("test_table", "my_idx_s", schema=config.test_schema) + assert not meth("test_table", "my_idx", schema=config.test_schema) + assert not meth( + "nonexistent_table", "my_idx_s", schema=config.test_schema + ) + assert not meth( + "test_table", "nonexistent_idx_s", schema=config.test_schema + ) class QuotedNameArgumentTest(fixtures.TablesTest): @@ -264,7 +303,12 @@ class QuotedNameArgumentTest(fixtures.TablesTest): def test_get_table_options(self, name): insp = inspect(config.db) - insp.get_table_options(name) + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options(name) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options(name) @quote_fixtures @testing.requires.view_column_reflection @@ -311,7 +355,37 @@ class QuotedNameArgumentTest(fixtures.TablesTest): assert insp.get_check_constraints(name) -class ComponentReflectionTest(fixtures.TablesTest): +def _multi_combination(fn): + schema = testing.combinations( + None, + ( + lambda: config.test_schema, + testing.requires.schemas, + ), + argnames="schema", + ) + scope = testing.combinations( + ObjectScope.DEFAULT, + ObjectScope.TEMPORARY, + ObjectScope.ANY, + argnames="scope", + ) + kind = testing.combinations( + ObjectKind.TABLE, + ObjectKind.VIEW, + ObjectKind.MATERIALIZED_VIEW, + ObjectKind.ANY, + ObjectKind.ANY_VIEW, + ObjectKind.TABLE | ObjectKind.VIEW, + ObjectKind.TABLE | ObjectKind.MATERIALIZED_VIEW, + argnames="kind", + ) + filter_names = testing.combinations(True, False, argnames="use_filter") + + return schema(scope(kind(filter_names(fn)))) + + +class ComponentReflectionTest(ComparesTables, fixtures.TablesTest): run_inserts = run_deletes = None __backend__ = True @@ -354,6 +428,7 @@ class ComponentReflectionTest(fixtures.TablesTest): "%susers.user_id" % schema_prefix, name="user_id_fk" ), ), + sa.CheckConstraint("test2 > 0", name="test2_gt_zero"), schema=schema, test_needs_fk=True, ) @@ -364,6 +439,8 @@ class ComponentReflectionTest(fixtures.TablesTest): Column("user_id", sa.INT, primary_key=True), Column("test1", sa.CHAR(5), nullable=False), Column("test2", sa.Float(), nullable=False), + Column("parent_user_id", sa.Integer), + sa.CheckConstraint("test2 > 0", name="test2_gt_zero"), schema=schema, test_needs_fk=True, ) @@ -375,9 +452,19 @@ class ComponentReflectionTest(fixtures.TablesTest): Column( "address_id", sa.Integer, - sa.ForeignKey("%semail_addresses.address_id" % schema_prefix), + sa.ForeignKey( + "%semail_addresses.address_id" % schema_prefix, + name="email_add_id_fg", + ), + ), + Column("data", sa.String(30), unique=True), + sa.CheckConstraint( + "address_id > 0 AND address_id < 1000", + name="address_id_gt_zero", + ), + sa.UniqueConstraint( + "address_id", "dingaling_id", name="zz_dingalings_multiple" ), - Column("data", sa.String(30)), schema=schema, test_needs_fk=True, ) @@ -388,7 +475,7 @@ class ComponentReflectionTest(fixtures.TablesTest): Column( "remote_user_id", sa.Integer, sa.ForeignKey(users.c.user_id) ), - Column("email_address", sa.String(20)), + Column("email_address", sa.String(20), index=True), sa.PrimaryKeyConstraint("address_id", name="email_ad_pk"), schema=schema, test_needs_fk=True, @@ -406,6 +493,12 @@ class ComponentReflectionTest(fixtures.TablesTest): schema=schema, comment=r"""the test % ' " \ table comment""", ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + ) if 
testing.requires.cross_schema_fk_reflection.enabled: if schema is None: @@ -449,7 +542,10 @@ class ComponentReflectionTest(fixtures.TablesTest): ) if testing.requires.index_reflection.enabled: - cls.define_index(metadata, users) + Index("users_t_idx", users.c.test1, users.c.test2, unique=True) + Index( + "users_all_idx", users.c.user_id, users.c.test2, users.c.test1 + ) if not schema: # test_needs_fk is at the moment to force MySQL InnoDB @@ -468,7 +564,10 @@ class ComponentReflectionTest(fixtures.TablesTest): test_needs_fk=True, ) - if testing.requires.indexes_with_ascdesc.enabled: + if ( + testing.requires.indexes_with_ascdesc.enabled + and testing.requires.reflect_indexes_with_ascdesc.enabled + ): Index("noncol_idx_nopk", noncol_idx_test_nopk.c.q.desc()) Index("noncol_idx_pk", noncol_idx_test_pk.c.q.desc()) @@ -478,11 +577,15 @@ class ComponentReflectionTest(fixtures.TablesTest): cls.define_temp_tables(metadata) @classmethod + def temp_table_name(cls): + return get_temp_table_name( + config, config.db, f"user_tmp_{config.ident}" + ) + + @classmethod def define_temp_tables(cls, metadata): kw = temp_table_keyword_args(config, config.db) - table_name = get_temp_table_name( - config, config.db, "user_tmp_%s" % config.ident - ) + table_name = cls.temp_table_name() user_tmp = Table( table_name, metadata, @@ -495,7 +598,7 @@ class ComponentReflectionTest(fixtures.TablesTest): # unique constraints created against temp tables in different # databases. # https://www.arbinada.com/en/node/1645 - sa.UniqueConstraint("name", name="user_tmp_uq_%s" % config.ident), + sa.UniqueConstraint("name", name=f"user_tmp_uq_{config.ident}"), sa.Index("user_tmp_ix", "foo"), **kw, ) @@ -514,32 +617,635 @@ class ComponentReflectionTest(fixtures.TablesTest): event.listen(user_tmp, "before_drop", DDL("drop view user_tmp_v")) @classmethod - def define_index(cls, metadata, users): - Index("users_t_idx", users.c.test1, users.c.test2) - Index("users_all_idx", users.c.user_id, users.c.test2, users.c.test1) - - @classmethod def define_views(cls, metadata, schema): - for table_name in ("users", "email_addresses"): + if testing.requires.materialized_views.enabled: + materialized = {"dingalings"} + else: + materialized = set() + for table_name in ("users", "email_addresses", "dingalings"): fullname = table_name if schema: - fullname = "%s.%s" % (schema, table_name) + fullname = f"{schema}.{table_name}" view_name = fullname + "_v" - query = "CREATE VIEW %s AS SELECT * FROM %s" % ( - view_name, - fullname, + prefix = "MATERIALIZED " if table_name in materialized else "" + query = ( + f"CREATE {prefix}VIEW {view_name} AS SELECT * FROM {fullname}" ) event.listen(metadata, "after_create", DDL(query)) + if table_name in materialized: + index_name = "mat_index" + if schema and testing.against("oracle"): + index_name = f"{schema}.{index_name}" + idx = f"CREATE INDEX {index_name} ON {view_name}(data)" + event.listen(metadata, "after_create", DDL(idx)) event.listen( - metadata, "before_drop", DDL("DROP VIEW %s" % view_name) + metadata, "before_drop", DDL(f"DROP {prefix}VIEW {view_name}") + ) + + def _resolve_kind(self, kind, tables, views, materialized): + res = {} + if ObjectKind.TABLE in kind: + res.update(tables) + if ObjectKind.VIEW in kind: + res.update(views) + if ObjectKind.MATERIALIZED_VIEW in kind: + res.update(materialized) + return res + + def _resolve_views(self, views, materialized): + if not testing.requires.view_column_reflection.enabled: + materialized.clear() + views.clear() + elif not 
testing.requires.materialized_views.enabled: + views.update(materialized) + materialized.clear() + + def _resolve_names(self, schema, scope, filter_names, values): + scope_filter = lambda _: True # noqa: E731 + if scope is ObjectScope.DEFAULT: + scope_filter = lambda k: "tmp" not in k[1] # noqa: E731 + if scope is ObjectScope.TEMPORARY: + scope_filter = lambda k: "tmp" in k[1] # noqa: E731 + + removed = { + None: {"remote_table", "remote_table_2"}, + testing.config.test_schema: { + "local_table", + "noncol_idx_test_nopk", + "noncol_idx_test_pk", + "user_tmp_v", + self.temp_table_name(), + }, + } + if not testing.requires.cross_schema_fk_reflection.enabled: + removed[None].add("local_table") + removed[testing.config.test_schema].update( + ["remote_table", "remote_table_2"] + ) + if not testing.requires.index_reflection.enabled: + removed[None].update( + ["noncol_idx_test_nopk", "noncol_idx_test_pk"] ) + if ( + not testing.requires.temp_table_reflection.enabled + or not testing.requires.temp_table_names.enabled + ): + removed[None].update(["user_tmp_v", self.temp_table_name()]) + if not testing.requires.temporary_views.enabled: + removed[None].update(["user_tmp_v"]) + + res = { + k: v + for k, v in values.items() + if scope_filter(k) + and k[1] not in removed[schema] + and (not filter_names or k[1] in filter_names) + } + return res + + def exp_options( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + materialized = {(schema, "dingalings_v"): mock.ANY} + views = { + (schema, "email_addresses_v"): mock.ANY, + (schema, "users_v"): mock.ANY, + (schema, "user_tmp_v"): mock.ANY, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): mock.ANY, + (schema, "dingalings"): mock.ANY, + (schema, "email_addresses"): mock.ANY, + (schema, "comment_test"): mock.ANY, + (schema, "no_constraints"): mock.ANY, + (schema, "local_table"): mock.ANY, + (schema, "remote_table"): mock.ANY, + (schema, "remote_table_2"): mock.ANY, + (schema, "noncol_idx_test_nopk"): mock.ANY, + (schema, "noncol_idx_test_pk"): mock.ANY, + (schema, self.temp_table_name()): mock.ANY, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_comments( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + empty = {"text": None} + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): empty, + (schema, "dingalings"): empty, + (schema, "email_addresses"): empty, + (schema, "comment_test"): { + "text": r"""the test % ' " \ table comment""" + }, + (schema, "no_constraints"): empty, + (schema, "local_table"): empty, + (schema, "remote_table"): empty, + (schema, "remote_table_2"): empty, + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): empty, + (schema, self.temp_table_name()): empty, + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + def exp_columns( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def col( + name, auto=False, default=mock.ANY, comment=None, nullable=True + ): + res = { + "name": name, + "autoincrement": auto, + "type": mock.ANY, + "default": default, + "comment": 
comment, + "nullable": nullable, + } + if auto == "omit": + res.pop("autoincrement") + return res + + def pk(name, **kw): + kw = {"auto": True, "default": mock.ANY, "nullable": False, **kw} + return col(name, **kw) + + materialized = { + (schema, "dingalings_v"): [ + col("dingaling_id", auto="omit", nullable=mock.ANY), + col("address_id"), + col("data"), + ] + } + views = { + (schema, "email_addresses_v"): [ + col("address_id", auto="omit", nullable=mock.ANY), + col("remote_user_id"), + col("email_address"), + ], + (schema, "users_v"): [ + col("user_id", auto="omit", nullable=mock.ANY), + col("test1", nullable=mock.ANY), + col("test2", nullable=mock.ANY), + col("parent_user_id"), + ], + (schema, "user_tmp_v"): [ + col("id", auto="omit", nullable=mock.ANY), + col("name"), + col("foo"), + ], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + pk("user_id"), + col("test1", nullable=False), + col("test2", nullable=False), + col("parent_user_id"), + ], + (schema, "dingalings"): [ + pk("dingaling_id"), + col("address_id"), + col("data"), + ], + (schema, "email_addresses"): [ + pk("address_id"), + col("remote_user_id"), + col("email_address"), + ], + (schema, "comment_test"): [ + pk("id", comment="id comment"), + col("data", comment="data % comment"), + col( + "d2", + comment=r"""Comment types type speedily ' " \ '' Fun!""", + ), + ], + (schema, "no_constraints"): [col("data")], + (schema, "local_table"): [pk("id"), col("data"), col("remote_id")], + (schema, "remote_table"): [pk("id"), col("local_id"), col("data")], + (schema, "remote_table_2"): [pk("id"), col("data")], + (schema, "noncol_idx_test_nopk"): [col("q")], + (schema, "noncol_idx_test_pk"): [pk("id"), col("q")], + (schema, self.temp_table_name()): [ + pk("id"), + col("name"), + col("foo"), + ], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_column_keys(self): + return {"name", "type", "nullable", "default"} + + def exp_pks( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def pk(*cols, name=mock.ANY): + return {"constrained_columns": list(cols), "name": name} + + empty = pk(name=None) + if testing.requires.materialized_views_reflect_pk.enabled: + materialized = {(schema, "dingalings_v"): pk("dingaling_id")} + else: + materialized = {(schema, "dingalings_v"): empty} + views = { + (schema, "email_addresses_v"): empty, + (schema, "users_v"): empty, + (schema, "user_tmp_v"): empty, + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): pk("user_id"), + (schema, "dingalings"): pk("dingaling_id"), + (schema, "email_addresses"): pk("address_id", name="email_ad_pk"), + (schema, "comment_test"): pk("id"), + (schema, "no_constraints"): empty, + (schema, "local_table"): pk("id"), + (schema, "remote_table"): pk("id"), + (schema, "remote_table_2"): pk("id"), + (schema, "noncol_idx_test_nopk"): empty, + (schema, "noncol_idx_test_pk"): pk("id"), + (schema, self.temp_table_name()): pk("id"), + } + if not testing.requires.reflects_pk_names.enabled: + for val in tables.values(): + if val["name"] is not None: + val["name"] = mock.ANY + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_pk_keys(self): + return {"name", "constrained_columns"} + + def exp_fks( + self, + schema=None, + scope=ObjectScope.ANY, + 
kind=ObjectKind.ANY, + filter_names=None, + ): + class tt: + def __eq__(self, other): + return ( + other is None + or config.db.dialect.default_schema_name == other + ) + + def fk(cols, ref_col, ref_table, ref_schema=schema, name=mock.ANY): + return { + "constrained_columns": cols, + "referred_columns": ref_col, + "name": name, + "options": mock.ANY, + "referred_schema": ref_schema + if ref_schema is not None + else tt(), + "referred_table": ref_table, + } + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + fk(["parent_user_id"], ["user_id"], "users", name="user_id_fk") + ], + (schema, "dingalings"): [ + fk( + ["address_id"], + ["address_id"], + "email_addresses", + name="email_add_id_fg", + ) + ], + (schema, "email_addresses"): [ + fk(["remote_user_id"], ["user_id"], "users") + ], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + fk( + ["remote_id"], + ["id"], + "remote_table_2", + ref_schema=config.test_schema, + ) + ], + (schema, "remote_table"): [ + fk(["local_id"], ["id"], "local_table", ref_schema=None) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + if not testing.requires.self_referential_foreign_keys.enabled: + tables[(schema, "users")].clear() + if not testing.requires.named_constraints.enabled: + for vals in tables.values(): + for val in vals: + if val["name"] is not mock.ANY: + val["name"] = mock.ANY + + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_fk_keys(self): + return { + "name", + "constrained_columns", + "referred_schema", + "referred_table", + "referred_columns", + } + + def exp_indexes( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + def idx( + *cols, + name, + unique=False, + column_sorting=None, + duplicates=False, + fk=False, + ): + fk_req = testing.requires.foreign_keys_reflect_as_index + dup_req = testing.requires.unique_constraints_reflect_as_index + if (fk and not fk_req.enabled) or ( + duplicates and not dup_req.enabled + ): + return () + res = { + "unique": unique, + "column_names": list(cols), + "name": name, + "dialect_options": mock.ANY, + "include_columns": [], + } + if column_sorting: + res["column_sorting"] = {"q": ("desc",)} + if duplicates: + res["duplicates_constraint"] = name + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + if materialized: + materialized[(schema, "dingalings_v")].extend( + idx("data", name="mat_index") + ) + tables = { + (schema, "users"): [ + *idx("parent_user_id", name="user_id_fk", fk=True), + *idx("user_id", "test2", "test1", name="users_all_idx"), + *idx("test1", "test2", name="users_t_idx", unique=True), + ], + (schema, "dingalings"): [ + *idx("data", name=mock.ANY, unique=True, duplicates=True), + *idx( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + unique=True, + duplicates=True, + ), + ], + (schema, "email_addresses"): [ + *idx("email_address", name=mock.ANY), + *idx("remote_user_id", name=mock.ANY, fk=True), + ], + 
(schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [ + *idx("remote_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table"): [ + *idx("local_id", name=mock.ANY, fk=True) + ], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [ + *idx( + "q", + name="noncol_idx_nopk", + column_sorting={"q": ("desc",)}, + ) + ], + (schema, "noncol_idx_test_pk"): [ + *idx( + "q", name="noncol_idx_pk", column_sorting={"q": ("desc",)} + ) + ], + (schema, self.temp_table_name()): [ + *idx("foo", name="user_tmp_ix"), + *idx( + "name", + name=f"user_tmp_uq_{config.ident}", + duplicates=True, + unique=True, + ), + ], + } + if ( + not testing.requires.indexes_with_ascdesc.enabled + or not testing.requires.reflect_indexes_with_ascdesc.enabled + ): + tables[(schema, "noncol_idx_test_nopk")].clear() + tables[(schema, "noncol_idx_test_pk")].clear() + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_index_keys(self): + return {"name", "column_names", "unique"} + + def exp_ucs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + all_=False, + ): + def uc(*cols, name, duplicates_index=None, is_index=False): + req = testing.requires.unique_index_reflect_as_unique_constraints + if is_index and not req.enabled: + return () + res = { + "column_names": list(cols), + "name": name, + } + if duplicates_index: + res["duplicates_index"] = duplicates_index + return [res] + + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [ + *uc( + "test1", + "test2", + name="users_t_idx", + duplicates_index="users_t_idx", + is_index=True, + ) + ], + (schema, "dingalings"): [ + *uc("data", name=mock.ANY, duplicates_index=mock.ANY), + *uc( + "address_id", + "dingaling_id", + name="zz_dingalings_multiple", + duplicates_index="zz_dingalings_multiple", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [ + *uc("name", name=f"user_tmp_uq_{config.ident}") + ], + } + if all_: + return {**materialized, **views, **tables} + else: + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_unique_cst_keys(self): + return {"name", "column_names"} + + def exp_ccs( + self, + schema=None, + scope=ObjectScope.ANY, + kind=ObjectKind.ANY, + filter_names=None, + ): + class tt(str): + def __eq__(self, other): + res = ( + other.lower() + .replace("(", "") + .replace(")", "") + .replace("`", "") + ) + return self in res + + def cc(text, name): + return {"sqltext": tt(text), "name": name} + + # print({1: "test2 > (0)::double precision"} == {1: tt("test2 > 0")}) + # assert 0 + materialized = {(schema, "dingalings_v"): []} + views = { + (schema, "email_addresses_v"): [], + (schema, "users_v"): [], + (schema, "user_tmp_v"): [], + } + self._resolve_views(views, materialized) + tables = { + (schema, "users"): [cc("test2 > 0", "test2_gt_zero")], + (schema, "dingalings"): [ + cc( + "address_id > 0 and 
address_id < 1000", + name="address_id_gt_zero", + ), + ], + (schema, "email_addresses"): [], + (schema, "comment_test"): [], + (schema, "no_constraints"): [], + (schema, "local_table"): [], + (schema, "remote_table"): [], + (schema, "remote_table_2"): [], + (schema, "noncol_idx_test_nopk"): [], + (schema, "noncol_idx_test_pk"): [], + (schema, self.temp_table_name()): [], + } + res = self._resolve_kind(kind, tables, views, materialized) + res = self._resolve_names(schema, scope, filter_names, res) + return res + + @property + def _required_cc_keys(self): + return {"name", "sqltext"} @testing.requires.schema_reflection - def test_get_schema_names(self): - insp = inspect(self.bind) + def test_get_schema_names(self, connection): + insp = inspect(connection) - self.assert_(testing.config.test_schema in insp.get_schema_names()) + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema(self, connection): + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) @testing.requires.schema_reflection def test_get_schema_names_w_translate_map(self, connection): @@ -553,7 +1259,37 @@ class ComponentReflectionTest(fixtures.TablesTest): ) insp = inspect(connection) - self.assert_(testing.config.test_schema in insp.get_schema_names()) + is_true(testing.config.test_schema in insp.get_schema_names()) + + @testing.requires.schema_reflection + def test_has_schema_w_translate_map(self, connection): + connection = connection.execution_options( + schema_translate_map={ + "foo": "bar", + BLANK_SCHEMA: testing.config.test_schema, + } + ) + insp = inspect(connection) + + is_true(insp.has_schema(testing.config.test_schema)) + is_false(insp.has_schema("sa_fake_schema_foo")) + + @testing.requires.schema_reflection + @testing.requires.schema_create_delete + def test_schema_cache(self, connection): + insp = inspect(connection) + + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + connection.execute(DDL("CREATE SCHEMA foo_bar")) + try: + is_false("foo_bar" in insp.get_schema_names()) + is_false(insp.has_schema("foo_bar")) + insp.clear_cache() + is_true("foo_bar" in insp.get_schema_names()) + is_true(insp.has_schema("foo_bar")) + finally: + connection.execute(DDL("DROP SCHEMA foo_bar")) @testing.requires.schema_reflection def test_dialect_initialize(self): @@ -562,113 +1298,115 @@ class ComponentReflectionTest(fixtures.TablesTest): assert hasattr(engine.dialect, "default_schema_name") @testing.requires.schema_reflection - def test_get_default_schema_name(self): - insp = inspect(self.bind) - eq_(insp.default_schema_name, self.bind.dialect.default_schema_name) + def test_get_default_schema_name(self, connection): + insp = inspect(connection) + eq_(insp.default_schema_name, connection.dialect.default_schema_name) - @testing.requires.foreign_key_constraint_reflection @testing.combinations( - (None, True, False, False), - (None, True, False, True, testing.requires.schemas), - ("foreign_key", True, False, False), - (None, False, True, False), - (None, False, True, True, testing.requires.schemas), - (None, True, True, False), - (None, True, True, True, testing.requires.schemas), - argnames="order_by,include_plain,include_views,use_schema", + None, + ("foreign_key", testing.requires.foreign_key_constraint_reflection), + argnames="order_by", ) - def test_get_table_names( - self, connection, order_by, include_plain, include_views, use_schema - ): + 
@testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_table_names(self, connection, order_by, use_schema): if use_schema: schema = config.test_schema else: schema = None - _ignore_tables = [ + _ignore_tables = { "comment_test", "noncol_idx_test_pk", "noncol_idx_test_nopk", "local_table", "remote_table", "remote_table_2", - ] + "no_constraints", + } insp = inspect(connection) - if include_views: - table_names = insp.get_view_names(schema) - table_names.sort() - answer = ["email_addresses_v", "users_v"] - eq_(sorted(table_names), answer) + if order_by: + tables = [ + rec[0] + for rec in insp.get_sorted_table_and_fkc_names(schema) + if rec[0] + ] + else: + tables = insp.get_table_names(schema) + table_names = [t for t in tables if t not in _ignore_tables] - if include_plain: - if order_by: - tables = [ - rec[0] - for rec in insp.get_sorted_table_and_fkc_names(schema) - if rec[0] - ] - else: - tables = insp.get_table_names(schema) - table_names = [t for t in tables if t not in _ignore_tables] + if order_by == "foreign_key": + answer = ["users", "email_addresses", "dingalings"] + eq_(table_names, answer) + else: + answer = ["dingalings", "email_addresses", "users"] + eq_(sorted(table_names), answer) - if order_by == "foreign_key": - answer = ["users", "email_addresses", "dingalings"] - eq_(table_names, answer) - else: - answer = ["dingalings", "email_addresses", "users"] - eq_(sorted(table_names), answer) + @testing.combinations( + (True, testing.requires.schemas), False, argnames="use_schema" + ) + def test_get_view_names(self, connection, use_schema): + insp = inspect(connection) + if use_schema: + schema = config.test_schema + else: + schema = None + table_names = insp.get_view_names(schema) + if testing.requires.materialized_views.enabled: + eq_(sorted(table_names), ["email_addresses_v", "users_v"]) + eq_(insp.get_materialized_view_names(schema), ["dingalings_v"]) + else: + answer = ["dingalings_v", "email_addresses_v", "users_v"] + eq_(sorted(table_names), answer) @testing.requires.temp_table_names - def test_get_temp_table_names(self): - insp = inspect(self.bind) + def test_get_temp_table_names(self, connection): + insp = inspect(connection) temp_table_names = insp.get_temp_table_names() - eq_(sorted(temp_table_names), ["user_tmp_%s" % config.ident]) + eq_(sorted(temp_table_names), [f"user_tmp_{config.ident}"]) @testing.requires.view_reflection - @testing.requires.temp_table_names @testing.requires.temporary_views - def test_get_temp_view_names(self): - insp = inspect(self.bind) + def test_get_temp_view_names(self, connection): + insp = inspect(connection) temp_table_names = insp.get_temp_view_names() eq_(sorted(temp_table_names), ["user_tmp_v"]) @testing.requires.comment_reflection - def test_get_comments(self): - self._test_get_comments() + def test_get_comments(self, connection): + self._test_get_comments(connection) @testing.requires.comment_reflection @testing.requires.schemas - def test_get_comments_with_schema(self): - self._test_get_comments(testing.config.test_schema) - - def _test_get_comments(self, schema=None): - insp = inspect(self.bind) + def test_get_comments_with_schema(self, connection): + self._test_get_comments(connection, testing.config.test_schema) + def _test_get_comments(self, connection, schema=None): + insp = inspect(connection) + exp = self.exp_comments(schema=schema) eq_( insp.get_table_comment("comment_test", schema=schema), - {"text": r"""the test % ' " \ table comment"""}, + exp[(schema, "comment_test")], ) 
- eq_(insp.get_table_comment("users", schema=schema), {"text": None}) + eq_( + insp.get_table_comment("users", schema=schema), + exp[(schema, "users")], + ) eq_( - [ - {"name": rec["name"], "comment": rec["comment"]} - for rec in insp.get_columns("comment_test", schema=schema) - ], - [ - {"comment": "id comment", "name": "id"}, - {"comment": "data % comment", "name": "data"}, - { - "comment": ( - r"""Comment types type speedily ' " \ '' Fun!""" - ), - "name": "d2", - }, - ], + insp.get_table_comment("comment_test", schema=schema), + exp[(schema, "comment_test")], + ) + + no_cst = self.tables.no_constraints.name + eq_( + insp.get_table_comment(no_cst, schema=schema), + exp[(schema, no_cst)], ) @testing.combinations( @@ -691,7 +1429,7 @@ class ComponentReflectionTest(fixtures.TablesTest): users, addresses = (self.tables.users, self.tables.email_addresses) if use_views: - table_names = ["users_v", "email_addresses_v"] + table_names = ["users_v", "email_addresses_v", "dingalings_v"] else: table_names = ["users", "email_addresses"] @@ -699,7 +1437,7 @@ class ComponentReflectionTest(fixtures.TablesTest): for table_name, table in zip(table_names, (users, addresses)): schema_name = schema cols = insp.get_columns(table_name, schema=schema_name) - self.assert_(len(cols) > 0, len(cols)) + is_true(len(cols) > 0, len(cols)) # should be in order @@ -721,7 +1459,7 @@ class ComponentReflectionTest(fixtures.TablesTest): # assert that the desired type and return type share # a base within one of the generic types. - self.assert_( + is_true( len( set(ctype.__mro__) .intersection(ctype_def.__mro__) @@ -745,15 +1483,29 @@ class ComponentReflectionTest(fixtures.TablesTest): if not col.primary_key: assert cols[i]["default"] is None + # The case of a table with no column + # is tested below in TableNoColumnsTest + @testing.requires.temp_table_reflection - def test_get_temp_table_columns(self): - table_name = get_temp_table_name( - config, self.bind, "user_tmp_%s" % config.ident + def test_reflect_table_temp_table(self, connection): + + table_name = self.temp_table_name() + user_tmp = self.tables[table_name] + + reflected_user_tmp = Table( + table_name, MetaData(), autoload_with=connection ) + self.assert_tables_equal( + user_tmp, reflected_user_tmp, strict_constraints=False + ) + + @testing.requires.temp_table_reflection + def test_get_temp_table_columns(self, connection): + table_name = self.temp_table_name() user_tmp = self.tables[table_name] - insp = inspect(self.bind) + insp = inspect(connection) cols = insp.get_columns(table_name) - self.assert_(len(cols) > 0, len(cols)) + is_true(len(cols) > 0, len(cols)) for i, col in enumerate(user_tmp.columns): eq_(col.name, cols[i]["name"]) @@ -761,8 +1513,8 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.temp_table_reflection @testing.requires.view_column_reflection @testing.requires.temporary_views - def test_get_temp_view_columns(self): - insp = inspect(self.bind) + def test_get_temp_view_columns(self, connection): + insp = inspect(connection) cols = insp.get_columns("user_tmp_v") eq_([col["name"] for col in cols], ["id", "name", "foo"]) @@ -778,18 +1530,27 @@ class ComponentReflectionTest(fixtures.TablesTest): users, addresses = self.tables.users, self.tables.email_addresses insp = inspect(connection) + exp = self.exp_pks(schema=schema) users_cons = insp.get_pk_constraint(users.name, schema=schema) - users_pkeys = users_cons["constrained_columns"] - eq_(users_pkeys, ["user_id"]) + self._check_list( + [users_cons], [exp[(schema, 
users.name)]], self._required_pk_keys + ) addr_cons = insp.get_pk_constraint(addresses.name, schema=schema) - addr_pkeys = addr_cons["constrained_columns"] - eq_(addr_pkeys, ["address_id"]) + exp_cols = exp[(schema, addresses.name)]["constrained_columns"] + eq_(addr_cons["constrained_columns"], exp_cols) with testing.requires.reflects_pk_names.fail_if(): eq_(addr_cons["name"], "email_ad_pk") + no_cst = self.tables.no_constraints.name + self._check_list( + [insp.get_pk_constraint(no_cst, schema=schema)], + [exp[(schema, no_cst)]], + self._required_pk_keys, + ) + @testing.combinations( (False,), (True, testing.requires.schemas), argnames="use_schema" ) @@ -815,31 +1576,33 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(fkey1["referred_schema"], expected_schema) eq_(fkey1["referred_table"], users.name) eq_(fkey1["referred_columns"], ["user_id"]) - if testing.requires.self_referential_foreign_keys.enabled: - eq_(fkey1["constrained_columns"], ["parent_user_id"]) + eq_(fkey1["constrained_columns"], ["parent_user_id"]) # addresses addr_fkeys = insp.get_foreign_keys(addresses.name, schema=schema) fkey1 = addr_fkeys[0] with testing.requires.implicitly_named_constraints.fail_if(): - self.assert_(fkey1["name"] is not None) + is_true(fkey1["name"] is not None) eq_(fkey1["referred_schema"], expected_schema) eq_(fkey1["referred_table"], users.name) eq_(fkey1["referred_columns"], ["user_id"]) eq_(fkey1["constrained_columns"], ["remote_user_id"]) + no_cst = self.tables.no_constraints.name + eq_(insp.get_foreign_keys(no_cst, schema=schema), []) + @testing.requires.cross_schema_fk_reflection @testing.requires.schemas - def test_get_inter_schema_foreign_keys(self): + def test_get_inter_schema_foreign_keys(self, connection): local_table, remote_table, remote_table_2 = self.tables( - "%s.local_table" % self.bind.dialect.default_schema_name, + "%s.local_table" % connection.dialect.default_schema_name, "%s.remote_table" % testing.config.test_schema, "%s.remote_table_2" % testing.config.test_schema, ) - insp = inspect(self.bind) + insp = inspect(connection) local_fkeys = insp.get_foreign_keys(local_table.name) eq_(len(local_fkeys), 1) @@ -857,25 +1620,21 @@ class ComponentReflectionTest(fixtures.TablesTest): fkey2 = remote_fkeys[0] - assert fkey2["referred_schema"] in ( - None, - self.bind.dialect.default_schema_name, + is_true( + fkey2["referred_schema"] + in ( + None, + connection.dialect.default_schema_name, + ) ) eq_(fkey2["referred_table"], local_table.name) eq_(fkey2["referred_columns"], ["id"]) eq_(fkey2["constrained_columns"], ["local_id"]) - def _assert_insp_indexes(self, indexes, expected_indexes): - index_names = [d["name"] for d in indexes] - for e_index in expected_indexes: - assert e_index["name"] in index_names - index = indexes[index_names.index(e_index["name"])] - for key in e_index: - eq_(e_index[key], index[key]) - @testing.combinations( (False,), (True, testing.requires.schemas), argnames="use_schema" ) + @testing.requires.index_reflection def test_get_indexes(self, connection, use_schema): if use_schema: @@ -885,21 +1644,19 @@ class ComponentReflectionTest(fixtures.TablesTest): # The database may decide to create indexes for foreign keys, etc. # so there may be more indexes than expected. 
- insp = inspect(self.bind) + insp = inspect(connection) indexes = insp.get_indexes("users", schema=schema) - expected_indexes = [ - { - "unique": False, - "column_names": ["test1", "test2"], - "name": "users_t_idx", - }, - { - "unique": False, - "column_names": ["user_id", "test2", "test1"], - "name": "users_all_idx", - }, - ] - self._assert_insp_indexes(indexes, expected_indexes) + exp = self.exp_indexes(schema=schema) + self._check_list( + indexes, exp[(schema, "users")], self._required_index_keys + ) + + no_cst = self.tables.no_constraints.name + self._check_list( + insp.get_indexes(no_cst, schema=schema), + exp[(schema, no_cst)], + self._required_index_keys, + ) @testing.combinations( ("noncol_idx_test_nopk", "noncol_idx_nopk"), @@ -908,15 +1665,15 @@ class ComponentReflectionTest(fixtures.TablesTest): ) @testing.requires.index_reflection @testing.requires.indexes_with_ascdesc + @testing.requires.reflect_indexes_with_ascdesc def test_get_noncol_index(self, connection, tname, ixname): insp = inspect(connection) indexes = insp.get_indexes(tname) - # reflecting an index that has "x DESC" in it as the column. # the DB may or may not give us "x", but make sure we get the index # back, it has a name, it's connected to the table. - expected_indexes = [{"unique": False, "name": ixname}] - self._assert_insp_indexes(indexes, expected_indexes) + expected_indexes = self.exp_indexes()[(None, tname)] + self._check_list(indexes, expected_indexes, self._required_index_keys) t = Table(tname, MetaData(), autoload_with=connection) eq_(len(t.indexes), 1) @@ -925,29 +1682,17 @@ class ComponentReflectionTest(fixtures.TablesTest): @testing.requires.temp_table_reflection @testing.requires.unique_constraint_reflection - def test_get_temp_table_unique_constraints(self): - insp = inspect(self.bind) - reflected = insp.get_unique_constraints("user_tmp_%s" % config.ident) - for refl in reflected: - # Different dialects handle duplicate index and constraints - # differently, so ignore this flag - refl.pop("duplicates_index", None) - eq_( - reflected, - [ - { - "column_names": ["name"], - "name": "user_tmp_uq_%s" % config.ident, - } - ], - ) + def test_get_temp_table_unique_constraints(self, connection): + insp = inspect(connection) + name = self.temp_table_name() + reflected = insp.get_unique_constraints(name) + exp = self.exp_ucs(all_=True)[(None, name)] + self._check_list(reflected, exp, self._required_index_keys) @testing.requires.temp_table_reflect_indexes - def test_get_temp_table_indexes(self): - insp = inspect(self.bind) - table_name = get_temp_table_name( - config, config.db, "user_tmp_%s" % config.ident - ) + def test_get_temp_table_indexes(self, connection): + insp = inspect(connection) + table_name = self.temp_table_name() indexes = insp.get_indexes(table_name) for ind in indexes: ind.pop("dialect_options", None) @@ -1005,9 +1750,9 @@ class ComponentReflectionTest(fixtures.TablesTest): ) table.create(connection) - inspector = inspect(connection) + insp = inspect(connection) reflected = sorted( - inspector.get_unique_constraints("testtbl", schema=schema), + insp.get_unique_constraints("testtbl", schema=schema), key=operator.itemgetter("name"), ) @@ -1047,6 +1792,9 @@ class ComponentReflectionTest(fixtures.TablesTest): eq_(names_that_duplicate_index, idx_names) eq_(uq_names, set()) + no_cst = self.tables.no_constraints.name + eq_(insp.get_unique_constraints(no_cst, schema=schema), []) + @testing.requires.view_reflection @testing.combinations( (False,), (True, testing.requires.schemas), 
argnames="use_schema" @@ -1056,32 +1804,21 @@ class ComponentReflectionTest(fixtures.TablesTest): schema = config.test_schema else: schema = None - view_name1 = "users_v" - view_name2 = "email_addresses_v" insp = inspect(connection) - v1 = insp.get_view_definition(view_name1, schema=schema) - self.assert_(v1) - v2 = insp.get_view_definition(view_name2, schema=schema) - self.assert_(v2) + for view in ["users_v", "email_addresses_v", "dingalings_v"]: + v = insp.get_view_definition(view, schema=schema) + is_true(bool(v)) - # why is this here if it's PG specific ? - @testing.combinations( - ("users", False), - ("users", True, testing.requires.schemas), - argnames="table_name,use_schema", - ) - @testing.only_on("postgresql", "PG specific feature") - def test_get_table_oid(self, connection, table_name, use_schema): - if use_schema: - schema = config.test_schema - else: - schema = None + @testing.requires.view_reflection + def test_get_view_definition_does_not_exist(self, connection): insp = inspect(connection) - oid = insp.get_table_oid(table_name, schema) - self.assert_(isinstance(oid, int)) + with expect_raises(NoSuchTableError): + insp.get_view_definition("view_does_not_exist") + with expect_raises(NoSuchTableError): + insp.get_view_definition("users") # a table @testing.requires.table_reflection - def test_autoincrement_col(self): + def test_autoincrement_col(self, connection): """test that 'autoincrement' is reflected according to sqla's policy. Don't mark this test as unsupported for any backend ! @@ -1094,7 +1831,7 @@ class ComponentReflectionTest(fixtures.TablesTest): """ - insp = inspect(self.bind) + insp = inspect(connection) for tname, cname in [ ("users", "user_id"), @@ -1105,6 +1842,330 @@ class ComponentReflectionTest(fixtures.TablesTest): id_ = {c["name"]: c for c in cols}[cname] assert id_.get("autoincrement", True) + @testing.combinations( + (True, testing.requires.schemas), (False,), argnames="use_schema" + ) + def test_get_table_options(self, use_schema): + insp = inspect(config.db) + schema = config.test_schema if use_schema else None + + if testing.requires.reflect_table_options.enabled: + res = insp.get_table_options("users", schema=schema) + is_true(isinstance(res, dict)) + # NOTE: can't really create a table with no option + res = insp.get_table_options("no_constraints", schema=schema) + is_true(isinstance(res, dict)) + else: + with expect_raises(NotImplementedError): + res = insp.get_table_options("users", schema=schema) + + @testing.combinations((True, testing.requires.schemas), False) + def test_multi_get_table_options(self, use_schema): + insp = inspect(config.db) + if testing.requires.reflect_table_options.enabled: + schema = config.test_schema if use_schema else None + res = insp.get_multi_table_options(schema=schema) + + exp = { + (schema, table): insp.get_table_options(table, schema=schema) + for table in insp.get_table_names(schema=schema) + } + eq_(res, exp) + else: + with expect_raises(NotImplementedError): + res = insp.get_multi_table_options() + + @testing.fixture + def get_multi_exp(self, connection): + def provide_fixture( + schema, scope, kind, use_filter, single_reflect_fn, exp_method + ): + insp = inspect(connection) + # call the reflection function at least once to avoid + # "Unexpected success" errors if the result is actually empty + # and NotImplementedError is not raised + single_reflect_fn(insp, "email_addresses") + kw = {"scope": scope, "kind": kind} + if schema: + schema = schema() + + filter_names = [] + + if ObjectKind.TABLE in kind: + 
filter_names.extend( + ["comment_test", "users", "does-not-exist"] + ) + if ObjectKind.VIEW in kind: + filter_names.extend(["email_addresses_v", "does-not-exist"]) + if ObjectKind.MATERIALIZED_VIEW in kind: + filter_names.extend(["dingalings_v", "does-not-exist"]) + + if schema: + kw["schema"] = schema + if use_filter: + kw["filter_names"] = filter_names + + exp = exp_method( + schema=schema, + scope=scope, + kind=kind, + filter_names=kw.get("filter_names"), + ) + kws = [kw] + if scope == ObjectScope.DEFAULT: + nkw = kw.copy() + nkw.pop("scope") + kws.append(nkw) + if kind == ObjectKind.TABLE: + nkw = kw.copy() + nkw.pop("kind") + kws.append(nkw) + + return inspect(connection), kws, exp + + return provide_fixture + + @testing.requires.reflect_table_options + @_multi_combination + def test_multi_get_table_options_tables( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_options, + self.exp_options, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_table_options(**kw) + eq_(result, exp) + + @testing.requires.comment_reflection + @_multi_combination + def test_get_multi_table_comment( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_table_comment, + self.exp_comments, + ) + for kw in kws: + insp.clear_cache() + eq_(insp.get_multi_table_comment(**kw), exp) + + def _check_list(self, result, exp, req_keys=None, msg=None): + if req_keys is None: + eq_(result, exp, msg) + else: + eq_(len(result), len(exp), msg) + for r, e in zip(result, exp): + for k in set(r) | set(e): + if k in req_keys or (k in r and k in e): + eq_(r[k], e[k], f"{msg} - {k} - {r}") + + def _check_table_dict(self, result, exp, req_keys=None, make_lists=False): + eq_(set(result.keys()), set(exp.keys())) + for k in result: + r, e = result[k], exp[k] + if make_lists: + r, e = [r], [e] + self._check_list(r, e, req_keys, k) + + @_multi_combination + def test_get_multi_columns( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_columns, + self.exp_columns, + ) + + for kw in kws: + insp.clear_cache() + result = insp.get_multi_columns(**kw) + self._check_table_dict(result, exp, self._required_column_keys) + + @testing.requires.primary_key_constraint_reflection + @_multi_combination + def test_get_multi_pk_constraint( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_pk_constraint, + self.exp_pks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_pk_constraint(**kw) + self._check_table_dict( + result, exp, self._required_pk_keys, make_lists=True + ) + + def _adjust_sort(self, result, expected, key): + if not testing.requires.implicitly_named_constraints.enabled: + for obj in [result, expected]: + for val in obj.values(): + if len(val) > 1 and any( + v.get("name") in (None, mock.ANY) for v in val + ): + val.sort(key=key) + + @testing.requires.foreign_key_constraint_reflection + @_multi_combination + def test_get_multi_foreign_keys( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_foreign_keys, + self.exp_fks, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_foreign_keys(**kw) + 
self._adjust_sort( + result, exp, lambda d: tuple(d["constrained_columns"]) + ) + self._check_table_dict(result, exp, self._required_fk_keys) + + @testing.requires.index_reflection + @_multi_combination + def test_get_multi_indexes( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_indexes, + self.exp_indexes, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_indexes(**kw) + self._check_table_dict(result, exp, self._required_index_keys) + + @testing.requires.unique_constraint_reflection + @_multi_combination + def test_get_multi_unique_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_unique_constraints, + self.exp_ucs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_unique_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["column_names"])) + self._check_table_dict(result, exp, self._required_unique_cst_keys) + + @testing.requires.check_constraint_reflection + @_multi_combination + def test_get_multi_check_constraints( + self, get_multi_exp, schema, scope, kind, use_filter + ): + insp, kws, exp = get_multi_exp( + schema, + scope, + kind, + use_filter, + Inspector.get_check_constraints, + self.exp_ccs, + ) + for kw in kws: + insp.clear_cache() + result = insp.get_multi_check_constraints(**kw) + self._adjust_sort(result, exp, lambda d: tuple(d["sqltext"])) + self._check_table_dict(result, exp, self._required_cc_keys) + + @testing.combinations( + ("get_table_options", testing.requires.reflect_table_options), + "get_columns", + ( + "get_pk_constraint", + testing.requires.primary_key_constraint_reflection, + ), + ( + "get_foreign_keys", + testing.requires.foreign_key_constraint_reflection, + ), + ("get_indexes", testing.requires.index_reflection), + ( + "get_unique_constraints", + testing.requires.unique_constraint_reflection, + ), + ( + "get_check_constraints", + testing.requires.check_constraint_reflection, + ), + ("get_table_comment", testing.requires.comment_reflection), + argnames="method", + ) + def test_not_existing_table(self, method, connection): + insp = inspect(connection) + meth = getattr(insp, method) + with expect_raises(NoSuchTableError): + meth("table_does_not_exists") + + def test_unreflectable(self, connection): + mc = Inspector.get_multi_columns + + def patched(*a, **k): + ur = k.setdefault("unreflectable", {}) + ur[(None, "some_table")] = UnreflectableTableError("err") + return mc(*a, **k) + + with mock.patch.object(Inspector, "get_multi_columns", patched): + with expect_raises_message(UnreflectableTableError, "err"): + inspect(connection).reflect_table( + Table("some_table", MetaData()), None + ) + + @testing.combinations(True, False, argnames="use_schema") + @testing.combinations( + (True, testing.requires.views), False, argnames="views" + ) + def test_metadata(self, connection, use_schema, views): + m = MetaData() + schema = config.test_schema if use_schema else None + m.reflect(connection, schema=schema, views=views, resolve_fks=False) + + insp = inspect(connection) + tables = insp.get_table_names(schema) + if views: + tables += insp.get_view_names(schema) + try: + tables += insp.get_materialized_view_names(schema) + except NotImplementedError: + pass + if schema: + tables = [f"{schema}.{t}" for t in tables] + eq_(sorted(m.tables), sorted(tables)) + class TableNoColumnsTest(fixtures.TestBase): 
__requires__ = ("reflect_tables_no_columns",) @@ -1118,9 +2179,6 @@ class TableNoColumnsTest(fixtures.TestBase): @testing.fixture def view_no_columns(self, connection, metadata): Table("empty", metadata) - metadata.create_all(connection) - - Table("empty", metadata) event.listen( metadata, "after_create", @@ -1134,31 +2192,32 @@ class TableNoColumnsTest(fixtures.TestBase): ) metadata.create_all(connection) - @testing.requires.reflect_tables_no_columns def test_reflect_table_no_columns(self, connection, table_no_columns): t2 = Table("empty", MetaData(), autoload_with=connection) eq_(list(t2.c), []) - @testing.requires.reflect_tables_no_columns def test_get_columns_table_no_columns(self, connection, table_no_columns): - eq_(inspect(connection).get_columns("empty"), []) + insp = inspect(connection) + eq_(insp.get_columns("empty"), []) + multi = insp.get_multi_columns() + eq_(multi, {(None, "empty"): []}) - @testing.requires.reflect_tables_no_columns def test_reflect_incl_table_no_columns(self, connection, table_no_columns): m = MetaData() m.reflect(connection) assert set(m.tables).intersection(["empty"]) @testing.requires.views - @testing.requires.reflect_tables_no_columns def test_reflect_view_no_columns(self, connection, view_no_columns): t2 = Table("empty_v", MetaData(), autoload_with=connection) eq_(list(t2.c), []) @testing.requires.views - @testing.requires.reflect_tables_no_columns def test_get_columns_view_no_columns(self, connection, view_no_columns): - eq_(inspect(connection).get_columns("empty_v"), []) + insp = inspect(connection) + eq_(insp.get_columns("empty_v"), []) + multi = insp.get_multi_columns(kind=ObjectKind.VIEW) + eq_(multi, {(None, "empty_v"): []}) class ComponentReflectionTestExtra(fixtures.TestBase): @@ -1185,12 +2244,18 @@ class ComponentReflectionTestExtra(fixtures.TestBase): ), schema=schema, ) + Table( + "no_constraints", + metadata, + Column("data", sa.String(20)), + schema=schema, + ) metadata.create_all(connection) - inspector = inspect(connection) + insp = inspect(connection) reflected = sorted( - inspector.get_check_constraints("sa_cc", schema=schema), + insp.get_check_constraints("sa_cc", schema=schema), key=operator.itemgetter("name"), ) @@ -1213,6 +2278,8 @@ class ComponentReflectionTestExtra(fixtures.TestBase): {"name": "cc1", "sqltext": "a > 1 and a < 5"}, ], ) + no_cst = "no_constraints" + eq_(insp.get_check_constraints(no_cst, schema=schema), []) @testing.requires.indexes_with_expressions def test_reflect_expression_based_indexes(self, metadata, connection): @@ -1642,7 +2709,8 @@ class IdentityReflectionTest(fixtures.TablesTest): if col["name"] == "normal": is_false("identity" in col) elif col["name"] == "id1": - is_true(col["autoincrement"] in (True, "auto")) + if "autoincrement" in col: + is_true(col["autoincrement"]) eq_(col["default"], None) is_true("identity" in col) self.check( @@ -1659,7 +2727,8 @@ class IdentityReflectionTest(fixtures.TablesTest): approx=True, ) elif col["name"] == "id2": - is_true(col["autoincrement"] in (True, "auto")) + if "autoincrement" in col: + is_true(col["autoincrement"]) eq_(col["default"], None) is_true("identity" in col) self.check( @@ -1685,7 +2754,8 @@ class IdentityReflectionTest(fixtures.TablesTest): if col["name"] == "normal": is_false("identity" in col) elif col["name"] == "id1": - is_true(col["autoincrement"] in (True, "auto")) + if "autoincrement" in col: + is_true(col["autoincrement"]) eq_(col["default"], None) is_true("identity" in col) self.check( @@ -1735,16 +2805,16 @@ class 
CompositeKeyReflectionTest(fixtures.TablesTest): ) @testing.requires.primary_key_constraint_reflection - def test_pk_column_order(self): + def test_pk_column_order(self, connection): # test for issue #5661 - insp = inspect(self.bind) + insp = inspect(connection) primary_key = insp.get_pk_constraint(self.tables.tb1.name) eq_(primary_key.get("constrained_columns"), ["name", "id", "attr"]) @testing.requires.foreign_key_constraint_reflection - def test_fk_column_order(self): + def test_fk_column_order(self, connection): # test for issue #5661 - insp = inspect(self.bind) + insp = inspect(connection) foreign_keys = insp.get_foreign_keys(self.tables.tb2.name) eq_(len(foreign_keys), 1) fkey1 = foreign_keys[0] diff --git a/lib/sqlalchemy/testing/suite/test_sequence.py b/lib/sqlalchemy/testing/suite/test_sequence.py index eae051992..e15fad642 100644 --- a/lib/sqlalchemy/testing/suite/test_sequence.py +++ b/lib/sqlalchemy/testing/suite/test_sequence.py @@ -194,16 +194,23 @@ class HasSequenceTest(fixtures.TablesTest): ) def test_has_sequence(self, connection): - eq_( - inspect(connection).has_sequence("user_id_seq"), - True, - ) + eq_(inspect(connection).has_sequence("user_id_seq"), True) + + def test_has_sequence_cache(self, connection, metadata): + insp = inspect(connection) + eq_(insp.has_sequence("user_id_seq"), True) + ss = Sequence("new_seq", metadata=metadata) + eq_(insp.has_sequence("new_seq"), False) + ss.create(connection) + try: + eq_(insp.has_sequence("new_seq"), False) + insp.clear_cache() + eq_(insp.has_sequence("new_seq"), True) + finally: + ss.drop(connection) def test_has_sequence_other_object(self, connection): - eq_( - inspect(connection).has_sequence("user_id_table"), - False, - ) + eq_(inspect(connection).has_sequence("user_id_table"), False) @testing.requires.schemas def test_has_sequence_schema(self, connection): @@ -215,10 +222,7 @@ class HasSequenceTest(fixtures.TablesTest): ) def test_has_sequence_neg(self, connection): - eq_( - inspect(connection).has_sequence("some_sequence"), - False, - ) + eq_(inspect(connection).has_sequence("some_sequence"), False) @testing.requires.schemas def test_has_sequence_schemas_neg(self, connection): @@ -240,10 +244,7 @@ class HasSequenceTest(fixtures.TablesTest): @testing.requires.schemas def test_has_sequence_remote_not_in_default(self, connection): - eq_( - inspect(connection).has_sequence("schema_seq"), - False, - ) + eq_(inspect(connection).has_sequence("schema_seq"), False) def test_get_sequence_names(self, connection): exp = {"other_seq", "user_id_seq"} diff --git a/lib/sqlalchemy/testing/util.py b/lib/sqlalchemy/testing/util.py index 0070b4d67..6fd42af70 100644 --- a/lib/sqlalchemy/testing/util.py +++ b/lib/sqlalchemy/testing/util.py @@ -393,36 +393,55 @@ def drop_all_tables_from_metadata(metadata, engine_or_connection): go(engine_or_connection) -def drop_all_tables(engine, inspector, schema=None, include_names=None): +def drop_all_tables( + engine, + inspector, + schema=None, + consider_schemas=(None,), + include_names=None, +): if include_names is not None: include_names = set(include_names) + if schema is not None: + assert consider_schemas == ( + None, + ), "consider_schemas and schema are mutually exclusive" + consider_schemas = (schema,) + with engine.begin() as conn: - for tname, fkcs in reversed( - inspector.get_sorted_table_and_fkc_names(schema=schema) + for table_key, fkcs in reversed( + inspector.sort_tables_on_foreign_key_dependency( + consider_schemas=consider_schemas + ) ): - if tname: - if include_names is not None and 
tname not in include_names: + if table_key: + if ( + include_names is not None + and table_key[1] not in include_names + ): continue conn.execute( - DropTable(Table(tname, MetaData(), schema=schema)) + DropTable( + Table(table_key[1], MetaData(), schema=table_key[0]) + ) ) elif fkcs: if not engine.dialect.supports_alter: continue - for tname, fkc in fkcs: + for t_key, fkc in fkcs: if ( include_names is not None - and tname not in include_names + and t_key[1] not in include_names ): continue tb = Table( - tname, + t_key[1], MetaData(), Column("x", Integer), Column("y", Integer), - schema=schema, + schema=t_key[0], ) conn.execute( DropConstraint( diff --git a/lib/sqlalchemy/util/topological.py b/lib/sqlalchemy/util/topological.py index 24e478b57..620e3bbb7 100644 --- a/lib/sqlalchemy/util/topological.py +++ b/lib/sqlalchemy/util/topological.py @@ -10,6 +10,7 @@ from __future__ import annotations from typing import Any +from typing import Collection from typing import DefaultDict from typing import Iterable from typing import Iterator @@ -27,7 +28,7 @@ __all__ = ["sort", "sort_as_subsets", "find_cycles"] def sort_as_subsets( - tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T] + tuples: Collection[Tuple[_T, _T]], allitems: Collection[_T] ) -> Iterator[Sequence[_T]]: edges: DefaultDict[_T, Set[_T]] = util.defaultdict(set) @@ -56,8 +57,8 @@ def sort_as_subsets( def sort( - tuples: Iterable[Tuple[_T, _T]], - allitems: Iterable[_T], + tuples: Collection[Tuple[_T, _T]], + allitems: Collection[_T], deterministic_order: bool = True, ) -> Iterator[_T]: """sort the given list of items by dependency. @@ -76,8 +77,7 @@ def sort( def find_cycles( - tuples: Iterable[Tuple[_T, _T]], - allitems: Iterable[_T], + tuples: Iterable[Tuple[_T, _T]], allitems: Iterable[_T] ) -> Set[_T]: # adapted from: # https://neopythonic.blogspot.com/2009/01/detecting-cycles-in-directed-graph.html diff --git a/lib/sqlalchemy/util/typing.py b/lib/sqlalchemy/util/typing.py index eb625e06e..4e76554c7 100644 --- a/lib/sqlalchemy/util/typing.py +++ b/lib/sqlalchemy/util/typing.py @@ -78,11 +78,13 @@ if typing.TYPE_CHECKING or compat.py38: from typing import Protocol as Protocol from typing import TypedDict as TypedDict from typing import Final as Final + from typing import final as final else: from typing_extensions import Literal as Literal # noqa: F401 from typing_extensions import Protocol as Protocol # noqa: F401 from typing_extensions import TypedDict as TypedDict # noqa: F401 from typing_extensions import Final as Final # noqa: F401 + from typing_extensions import final as final # noqa: F401 typing_get_args = get_args typing_get_origin = get_origin diff --git a/pyproject.toml b/pyproject.toml index 812d60e91..3807b9c10 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,17 +60,4 @@ module = [ warn_unused_ignores = false strict = true -[[tool.mypy.overrides]] - -##################################################################### -# interim list of modules that need some level of type checking to -# pass -module = [ - - "sqlalchemy.engine.reflection", - -] - -ignore_errors = true -warn_unused_ignores = false @@ -128,6 +128,9 @@ profile_file = test/profiles.txt # create public database link test_link connect to scott identified by tiger # using 'xe'; oracle_db_link = test_link +# create public database link test_link2 connect to test_schema identified by tiger +# using 'xe'; +oracle_db_link2 = test_link2 # host name of a postgres database that has the postgres_fdw extension. 
# to create this run: @@ -162,8 +165,8 @@ mariadb = mariadb+mysqldb://scott:tiger@127.0.0.1:3306/test mariadb_connector = mariadb+mariadbconnector://scott:tiger@127.0.0.1:3306/test mssql = mssql+pyodbc://scott:tiger^5HHH@mssql2017:1433/test?driver=ODBC+Driver+13+for+SQL+Server mssql_pymssql = mssql+pymssql://scott:tiger@ms_2008 -docker_mssql = mssql+pymssql://scott:tiger^5HHH@127.0.0.1:1433/test +docker_mssql = mssql+pyodbc://scott:tiger^5HHH@127.0.0.1:1433/test?driver=ODBC+Driver+17+for+SQL+Server oracle = oracle+cx_oracle://scott:tiger@oracle18c/xe cxoracle = oracle+cx_oracle://scott:tiger@oracle18c/xe -oracle_oracledb = oracle+oracledb://scott:tiger@oracle18c/xe oracledb = oracle+oracledb://scott:tiger@oracle18c/xe +docker_oracle = oracle+cx_oracle://scott:tiger@127.0.0.1:1521/?service_name=XEPDB1
\ No newline at end of file diff --git a/test/dialect/mysql/test_reflection.py b/test/dialect/mysql/test_reflection.py index f414c9c37..846001347 100644 --- a/test/dialect/mysql/test_reflection.py +++ b/test/dialect/mysql/test_reflection.py @@ -33,9 +33,9 @@ from sqlalchemy import UniqueConstraint from sqlalchemy.dialects.mysql import base as mysql from sqlalchemy.dialects.mysql import reflection as _reflection from sqlalchemy.schema import CreateIndex -from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises_message from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -558,6 +558,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): ) def test_skip_not_describable(self, metadata, connection): + """This test is the only one that test the _default_multi_reflect + behaviour with UnreflectableTableError + """ + @event.listens_for(metadata, "before_drop") def cleanup(*arg, **kw): with testing.db.begin() as conn: @@ -579,14 +583,10 @@ class ReflectionTest(fixtures.TestBase, AssertsCompiledSQL): m.reflect(views=True, bind=conn) eq_(m.tables["test_t2"].name, "test_t2") - assert_raises_message( - exc.UnreflectableTableError, - "references invalid table", - Table, - "test_v", - MetaData(), - autoload_with=conn, - ) + with expect_raises_message( + exc.UnreflectableTableError, "references invalid table" + ): + Table("test_v", MetaData(), autoload_with=conn) @testing.exclude("mysql", "<", (5, 0, 0), "no information_schema support") def test_system_views(self): diff --git a/test/dialect/oracle/test_reflection.py b/test/dialect/oracle/test_reflection.py index bf76dca43..53eb94df3 100644 --- a/test/dialect/oracle/test_reflection.py +++ b/test/dialect/oracle/test_reflection.py @@ -27,14 +27,18 @@ from sqlalchemy.dialects.oracle.base import BINARY_FLOAT from sqlalchemy.dialects.oracle.base import DOUBLE_PRECISION from sqlalchemy.dialects.oracle.base import NUMBER from sqlalchemy.dialects.oracle.base import REAL +from sqlalchemy.engine import ObjectKind from sqlalchemy.testing import assert_warns from sqlalchemy.testing import AssertsCompiledSQL +from sqlalchemy.testing import config from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ from sqlalchemy.testing import is_true from sqlalchemy.testing.engines import testing_engine from sqlalchemy.testing.schema import Column +from sqlalchemy.testing.schema import eq_compile_type from sqlalchemy.testing.schema import Table @@ -384,6 +388,7 @@ class SystemTableTablenamesTest(fixtures.TestBase): __backend__ = True def setup_test(self): + with testing.db.begin() as conn: conn.exec_driver_sql("create table my_table (id integer)") conn.exec_driver_sql( @@ -417,6 +422,14 @@ class SystemTableTablenamesTest(fixtures.TestBase): set(["my_table", "foo_table"]), ) + def test_reflect_system_table(self): + meta = MetaData() + t = Table("foo_table", meta, autoload_with=testing.db) + assert t.columns.keys() == ["id"] + + t = Table("my_temp_table", meta, autoload_with=testing.db) + assert t.columns.keys() == ["id"] + class DontReflectIOTTest(fixtures.TestBase): """test that index overflow tables aren't included in @@ -509,6 +522,228 @@ class TableReflectionTest(fixtures.TestBase): tbl = Table("test_compress", m2, autoload_with=connection) assert 
tbl.dialect_options["oracle"]["compress"] == "OLTP" + def test_reflect_hidden_column(self): + with testing.db.begin() as conn: + conn.exec_driver_sql( + "CREATE TABLE my_table(id integer, hide integer INVISIBLE)" + ) + + try: + insp = inspect(conn) + cols = insp.get_columns("my_table") + assert len(cols) == 1 + assert cols[0]["name"] == "id" + finally: + conn.exec_driver_sql("DROP TABLE my_table") + + +class ViewReflectionTest(fixtures.TestBase): + __only_on__ = "oracle" + __backend__ = True + + @classmethod + def setup_test_class(cls): + sql = """ + CREATE TABLE tbl ( + id INTEGER PRIMARY KEY, + data INTEGER + ); + + CREATE VIEW tbl_plain_v AS + SELECT id, data FROM tbl WHERE id > 100; + + -- comments on plain views are created with "comment on table" + -- because why not.. + COMMENT ON TABLE tbl_plain_v IS 'view comment'; + + CREATE MATERIALIZED VIEW tbl_v AS + SELECT id, data FROM tbl WHERE id > 42; + + COMMENT ON MATERIALIZED VIEW tbl_v IS 'my mat view comment'; + + CREATE MATERIALIZED VIEW tbl_v2 AS + SELECT id, data FROM tbl WHERE id < 42; + + COMMENT ON MATERIALIZED VIEW tbl_v2 IS 'my other mat view comment'; + + CREATE SYNONYM view_syn FOR tbl_plain_v; + CREATE SYNONYM %(test_schema)s.ts_v_s FOR tbl_plain_v; + + CREATE VIEW %(test_schema)s.schema_view AS + SELECT 1 AS value FROM dual; + + COMMENT ON TABLE %(test_schema)s.schema_view IS 'schema view comment'; + CREATE SYNONYM syn_schema_view FOR %(test_schema)s.schema_view; + """ + if testing.requires.oracle_test_dblink.enabled: + cls.dblink = config.file_config.get( + "sqla_testing", "oracle_db_link" + ) + sql += """ + CREATE SYNONYM syn_link FOR tbl_plain_v@%(link)s; + """ % { + "link": cls.dblink + } + with testing.db.begin() as conn: + for stmt in ( + sql % {"test_schema": testing.config.test_schema} + ).split(";"): + if stmt.strip(): + conn.exec_driver_sql(stmt) + + @classmethod + def teardown_test_class(cls): + sql = """ + DROP MATERIALIZED VIEW tbl_v; + DROP MATERIALIZED VIEW tbl_v2; + DROP VIEW tbl_plain_v; + DROP TABLE tbl; + DROP VIEW %(test_schema)s.schema_view; + DROP SYNONYM view_syn; + DROP SYNONYM %(test_schema)s.ts_v_s; + DROP SYNONYM syn_schema_view; + """ + if testing.requires.oracle_test_dblink.enabled: + sql += """ + DROP SYNONYM syn_link; + """ + with testing.db.begin() as conn: + for stmt in ( + sql % {"test_schema": testing.config.test_schema} + ).split(";"): + if stmt.strip(): + conn.exec_driver_sql(stmt) + + def test_get_names(self, connection): + insp = inspect(connection) + eq_(insp.get_table_names(), ["tbl"]) + eq_(insp.get_view_names(), ["tbl_plain_v"]) + eq_(insp.get_materialized_view_names(), ["tbl_v", "tbl_v2"]) + eq_( + insp.get_view_names(schema=testing.config.test_schema), + ["schema_view"], + ) + + def test_get_table_comment_on_view(self, connection): + insp = inspect(connection) + eq_(insp.get_table_comment("tbl_v"), {"text": "my mat view comment"}) + eq_(insp.get_table_comment("tbl_plain_v"), {"text": "view comment"}) + + def test_get_multi_view_comment(self, connection): + insp = inspect(connection) + plain = {(None, "tbl_plain_v"): {"text": "view comment"}} + mat = { + (None, "tbl_v"): {"text": "my mat view comment"}, + (None, "tbl_v2"): {"text": "my other mat view comment"}, + } + eq_(insp.get_multi_table_comment(kind=ObjectKind.VIEW), plain) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.MATERIALIZED_VIEW), + mat, + ) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW), + {**plain, **mat}, + ) + ts = testing.config.test_schema + eq_( + 
insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW, schema=ts), + {(ts, "schema_view"): {"text": "schema view comment"}}, + ) + eq_(insp.get_multi_table_comment(), {(None, "tbl"): {"text": None}}) + + def test_get_table_comment_synonym(self, connection): + insp = inspect(connection) + eq_( + insp.get_table_comment("view_syn", oracle_resolve_synonyms=True), + {"text": "view comment"}, + ) + eq_( + insp.get_table_comment( + "syn_schema_view", oracle_resolve_synonyms=True + ), + {"text": "schema view comment"}, + ) + eq_( + insp.get_table_comment( + "ts_v_s", + oracle_resolve_synonyms=True, + schema=testing.config.test_schema, + ), + {"text": "view comment"}, + ) + + def test_get_multi_view_comment_synonym(self, connection): + insp = inspect(connection) + exp = { + (None, "view_syn"): {"text": "view comment"}, + (None, "syn_schema_view"): {"text": "schema view comment"}, + } + if testing.requires.oracle_test_dblink.enabled: + exp[(None, "syn_link")] = {"text": "view comment"} + eq_( + insp.get_multi_table_comment( + oracle_resolve_synonyms=True, kind=ObjectKind.ANY_VIEW + ), + exp, + ) + ts = testing.config.test_schema + eq_( + insp.get_multi_table_comment( + oracle_resolve_synonyms=True, + schema=ts, + kind=ObjectKind.ANY_VIEW, + ), + {(ts, "ts_v_s"): {"text": "view comment"}}, + ) + + def test_get_view_definition(self, connection): + insp = inspect(connection) + eq_( + insp.get_view_definition("tbl_plain_v"), + "SELECT id, data FROM tbl WHERE id > 100", + ) + eq_( + insp.get_view_definition("tbl_v"), + "SELECT id, data FROM tbl WHERE id > 42", + ) + with expect_raises(exc.NoSuchTableError): + eq_(insp.get_view_definition("view_syn"), None) + eq_( + insp.get_view_definition("view_syn", oracle_resolve_synonyms=True), + "SELECT id, data FROM tbl WHERE id > 100", + ) + eq_( + insp.get_view_definition( + "syn_schema_view", oracle_resolve_synonyms=True + ), + "SELECT 1 AS value FROM dual", + ) + eq_( + insp.get_view_definition( + "ts_v_s", + oracle_resolve_synonyms=True, + schema=testing.config.test_schema, + ), + "SELECT id, data FROM tbl WHERE id > 100", + ) + + @testing.requires.oracle_test_dblink + def test_get_view_definition_dblink(self, connection): + insp = inspect(connection) + eq_( + insp.get_view_definition("syn_link", oracle_resolve_synonyms=True), + "SELECT id, data FROM tbl WHERE id > 100", + ) + eq_( + insp.get_view_definition("tbl_plain_v", dblink=self.dblink), + "SELECT id, data FROM tbl WHERE id > 100", + ) + eq_( + insp.get_view_definition("tbl_v", dblink=self.dblink), + "SELECT id, data FROM tbl WHERE id > 42", + ) + class RoundTripIndexTest(fixtures.TestBase): __only_on__ = "oracle" @@ -722,8 +957,6 @@ class DBLinkReflectionTest(fixtures.TestBase): @classmethod def setup_test_class(cls): - from sqlalchemy.testing import config - cls.dblink = config.file_config.get("sqla_testing", "oracle_db_link") # note that the synonym here is still not totally functional @@ -863,3 +1096,320 @@ class IdentityReflectionTest(fixtures.TablesTest): exp = common.copy() exp["order"] = True eq_(col["identity"], exp) + + +class AdditionalReflectionTests(fixtures.TestBase): + __only_on__ = "oracle" + __backend__ = True + + @classmethod + def setup_test_class(cls): + # currently assuming full DBA privs for the user. + # don't really know how else to go here unless + # we connect as the other user. 
+ + sql = """ +CREATE TABLE %(schema)sparent( + id INTEGER, + data VARCHAR2(50), + CONSTRAINT parent_pk_%(schema_id)s PRIMARY KEY (id) +); +CREATE TABLE %(schema)smy_table( + id INTEGER, + name VARCHAR2(125), + related INTEGER, + data%(schema_id)s NUMBER NOT NULL, + CONSTRAINT my_table_pk_%(schema_id)s PRIMARY KEY (id), + CONSTRAINT my_table_fk_%(schema_id)s FOREIGN KEY(related) + REFERENCES %(schema)sparent(id), + CONSTRAINT my_table_check_%(schema_id)s CHECK (data%(schema_id)s > 42), + CONSTRAINT data_unique%(schema_id)s UNIQUE (data%(schema_id)s) +); +CREATE INDEX my_table_index_%(schema_id)s on %(schema)smy_table (id, name); +COMMENT ON TABLE %(schema)smy_table IS 'my table comment %(schema_id)s'; +COMMENT ON COLUMN %(schema)smy_table.name IS +'my table.name comment %(schema_id)s'; +""" + + with testing.db.begin() as conn: + for schema in ("", testing.config.test_schema): + dd = { + "schema": f"{schema}." if schema else "", + "schema_id": "sch" if schema else "", + } + for stmt in (sql % dd).split(";"): + if stmt.strip(): + conn.exec_driver_sql(stmt) + + @classmethod + def teardown_test_class(cls): + sql = """ +drop table %(schema)smy_table; +drop table %(schema)sparent; +""" + with testing.db.begin() as conn: + for schema in ("", testing.config.test_schema): + dd = {"schema": f"{schema}." if schema else ""} + for stmt in (sql % dd).split(";"): + if stmt.strip(): + try: + conn.exec_driver_sql(stmt) + except: + pass + + def setup_test(self): + self.dblink = config.file_config.get("sqla_testing", "oracle_db_link") + self.dblink2 = config.file_config.get( + "sqla_testing", "oracle_db_link2" + ) + self.columns = {} + self.indexes = {} + self.primary_keys = {} + self.comments = {} + self.uniques = {} + self.checks = {} + self.foreign_keys = {} + self.options = {} + self.allDicts = [ + self.columns, + self.indexes, + self.primary_keys, + self.comments, + self.uniques, + self.checks, + self.foreign_keys, + self.options, + ] + for schema in (None, testing.config.test_schema): + suffix = "sch" if schema else "" + + self.columns[schema] = { + (schema, "my_table"): [ + { + "name": "id", + "nullable": False, + "type": eq_compile_type("INTEGER"), + "default": None, + "comment": None, + }, + { + "name": "name", + "nullable": True, + "type": eq_compile_type("VARCHAR(125)"), + "default": None, + "comment": f"my table.name comment {suffix}", + }, + { + "name": "related", + "nullable": True, + "type": eq_compile_type("INTEGER"), + "default": None, + "comment": None, + }, + { + "name": f"data{suffix}", + "nullable": False, + "type": eq_compile_type("NUMBER"), + "default": None, + "comment": None, + }, + ], + (schema, "parent"): [ + { + "name": "id", + "nullable": False, + "type": eq_compile_type("INTEGER"), + "default": None, + "comment": None, + }, + { + "name": "data", + "nullable": True, + "type": eq_compile_type("VARCHAR(50)"), + "default": None, + "comment": None, + }, + ], + } + self.indexes[schema] = { + (schema, "my_table"): [ + { + "name": f"data_unique{suffix}", + "column_names": [f"data{suffix}"], + "dialect_options": {}, + "unique": True, + }, + { + "name": f"my_table_index_{suffix}", + "column_names": ["id", "name"], + "dialect_options": {}, + "unique": False, + }, + ], + (schema, "parent"): [], + } + self.primary_keys[schema] = { + (schema, "my_table"): { + "name": f"my_table_pk_{suffix}", + "constrained_columns": ["id"], + }, + (schema, "parent"): { + "name": f"parent_pk_{suffix}", + "constrained_columns": ["id"], + }, + } + self.comments[schema] = { + (schema, "my_table"): {"text": f"my 
table comment {suffix}"}, + (schema, "parent"): {"text": None}, + } + self.foreign_keys[schema] = { + (schema, "my_table"): [ + { + "name": f"my_table_fk_{suffix}", + "constrained_columns": ["related"], + "referred_schema": schema, + "referred_table": "parent", + "referred_columns": ["id"], + "options": {}, + } + ], + (schema, "parent"): [], + } + self.checks[schema] = { + (schema, "my_table"): [ + { + "name": f"my_table_check_{suffix}", + "sqltext": f"data{suffix} > 42", + } + ], + (schema, "parent"): [], + } + self.uniques[schema] = { + (schema, "my_table"): [ + { + "name": f"data_unique{suffix}", + "column_names": [f"data{suffix}"], + "duplicates_index": f"data_unique{suffix}", + } + ], + (schema, "parent"): [], + } + self.options[schema] = { + (schema, "my_table"): {}, + (schema, "parent"): {}, + } + + def test_tables(self, connection): + insp = inspect(connection) + + eq_(sorted(insp.get_table_names()), ["my_table", "parent"]) + + def _check_reflection(self, conn, schema, res_schema=False, **kw): + if res_schema is False: + res_schema = schema + insp = inspect(conn) + eq_( + insp.get_multi_columns(schema=schema, **kw), + self.columns[res_schema], + ) + eq_( + insp.get_multi_indexes(schema=schema, **kw), + self.indexes[res_schema], + ) + eq_( + insp.get_multi_pk_constraint(schema=schema, **kw), + self.primary_keys[res_schema], + ) + eq_( + insp.get_multi_table_comment(schema=schema, **kw), + self.comments[res_schema], + ) + eq_( + insp.get_multi_foreign_keys(schema=schema, **kw), + self.foreign_keys[res_schema], + ) + eq_( + insp.get_multi_check_constraints(schema=schema, **kw), + self.checks[res_schema], + ) + eq_( + insp.get_multi_unique_constraints(schema=schema, **kw), + self.uniques[res_schema], + ) + eq_( + insp.get_multi_table_options(schema=schema, **kw), + self.options[res_schema], + ) + + @testing.combinations(True, False, argnames="schema") + def test_schema_translate_map(self, connection, schema): + schema = testing.config.test_schema if schema else None + c = connection.execution_options( + schema_translate_map={ + None: "foo", + testing.config.test_schema: "bar", + } + ) + self._check_reflection(c, schema) + + @testing.requires.oracle_test_dblink + def test_db_link(self, connection): + self._check_reflection(connection, schema=None, dblink=self.dblink) + self._check_reflection( + connection, + schema=testing.config.test_schema, + dblink=self.dblink, + ) + + def test_no_synonyms(self, connection): + # oracle_resolve_synonyms is ignored if there are no matching synonym + self._check_reflection( + connection, schema=None, oracle_resolve_synonyms=True + ) + connection.exec_driver_sql("CREATE SYNONYM tmp FOR parent") + for dict_ in self.allDicts: + dict_["tmp"] = {(None, "parent"): dict_[None][(None, "parent")]} + try: + self._check_reflection( + connection, + schema=None, + res_schema="tmp", + oracle_resolve_synonyms=True, + filter_names=["parent"], + ) + finally: + connection.exec_driver_sql("DROP SYNONYM tmp") + + @testing.requires.oracle_test_dblink + @testing.requires.oracle_test_dblink2 + def test_multi_dblink_synonyms(self, connection): + # oracle_resolve_synonyms handles multiple dblink at once + connection.exec_driver_sql( + f"CREATE SYNONYM s1 FOR my_table@{self.dblink}" + ) + connection.exec_driver_sql( + f"CREATE SYNONYM s2 FOR {testing.config.test_schema}." 
+ f"my_table@{self.dblink2}" + ) + connection.exec_driver_sql("CREATE SYNONYM s3 FOR parent") + for dict_ in self.allDicts: + dict_["tmp"] = { + (None, "s1"): dict_[None][(None, "my_table")], + (None, "s2"): dict_[testing.config.test_schema][ + (testing.config.test_schema, "my_table") + ], + (None, "s3"): dict_[None][(None, "parent")], + } + fk = self.foreign_keys["tmp"][(None, "s1")][0] + fk["referred_table"] = "s3" + try: + self._check_reflection( + connection, + schema=None, + res_schema="tmp", + oracle_resolve_synonyms=True, + ) + finally: + connection.exec_driver_sql("DROP SYNONYM s1") + connection.exec_driver_sql("DROP SYNONYM s2") + connection.exec_driver_sql("DROP SYNONYM s3") diff --git a/test/dialect/postgresql/test_dialect.py b/test/dialect/postgresql/test_dialect.py index 0093eb5ba..d55aa8203 100644 --- a/test/dialect/postgresql/test_dialect.py +++ b/test/dialect/postgresql/test_dialect.py @@ -887,19 +887,12 @@ class MiscBackendTest( ) @testing.combinations( - ((8, 1), False, False), - ((8, 1), None, False), - ((11, 5), True, False), - ((11, 5), False, True), + (True, False), + (False, True), ) - def test_backslash_escapes_detection( - self, version, explicit_setting, expected - ): + def test_backslash_escapes_detection(self, explicit_setting, expected): engine = engines.testing_engine() - def _server_version(conn): - return version - if explicit_setting is not None: @event.listens_for(engine, "connect", insert=True) @@ -912,11 +905,8 @@ class MiscBackendTest( ) dbapi_connection.commit() - with mock.patch.object( - engine.dialect, "_get_server_version_info", _server_version - ): - with engine.connect(): - eq_(engine.dialect._backslash_escapes, expected) + with engine.connect(): + eq_(engine.dialect._backslash_escapes, expected) def test_dbapi_autocommit_attribute(self): """all the supported DBAPIs have an .autocommit attribute. 
make diff --git a/test/dialect/postgresql/test_reflection.py b/test/dialect/postgresql/test_reflection.py index cbb1809e4..00e5dc5b9 100644 --- a/test/dialect/postgresql/test_reflection.py +++ b/test/dialect/postgresql/test_reflection.py @@ -27,17 +27,21 @@ from sqlalchemy.dialects.postgresql import base as postgresql from sqlalchemy.dialects.postgresql import ExcludeConstraint from sqlalchemy.dialects.postgresql import INTEGER from sqlalchemy.dialects.postgresql import INTERVAL +from sqlalchemy.dialects.postgresql import pg_catalog from sqlalchemy.dialects.postgresql import TSRANGE +from sqlalchemy.engine import ObjectKind +from sqlalchemy.engine import ObjectScope from sqlalchemy.schema import CreateIndex from sqlalchemy.sql.schema import CheckConstraint from sqlalchemy.testing import AssertsCompiledSQL from sqlalchemy.testing import fixtures from sqlalchemy.testing import mock -from sqlalchemy.testing.assertions import assert_raises from sqlalchemy.testing.assertions import assert_warns from sqlalchemy.testing.assertions import AssertsExecutionResults from sqlalchemy.testing.assertions import eq_ +from sqlalchemy.testing.assertions import expect_raises from sqlalchemy.testing.assertions import is_ +from sqlalchemy.testing.assertions import is_false from sqlalchemy.testing.assertions import is_true @@ -231,17 +235,36 @@ class MaterializedViewReflectionTest( connection.execute(target.insert(), {"id": 89, "data": "d1"}) materialized_view = sa.DDL( - "CREATE MATERIALIZED VIEW test_mview AS " "SELECT * FROM testtable" + "CREATE MATERIALIZED VIEW test_mview AS SELECT * FROM testtable" ) plain_view = sa.DDL( - "CREATE VIEW test_regview AS " "SELECT * FROM testtable" + "CREATE VIEW test_regview AS SELECT data FROM testtable" ) sa.event.listen(testtable, "after_create", plain_view) sa.event.listen(testtable, "after_create", materialized_view) sa.event.listen( testtable, + "after_create", + sa.DDL("COMMENT ON VIEW test_regview IS 'regular view comment'"), + ) + sa.event.listen( + testtable, + "after_create", + sa.DDL( + "COMMENT ON MATERIALIZED VIEW test_mview " + "IS 'materialized view comment'" + ), + ) + sa.event.listen( + testtable, + "after_create", + sa.DDL("CREATE INDEX mat_index ON test_mview(data DESC)"), + ) + + sa.event.listen( + testtable, "before_drop", sa.DDL("DROP MATERIALIZED VIEW test_mview"), ) @@ -249,6 +272,12 @@ class MaterializedViewReflectionTest( testtable, "before_drop", sa.DDL("DROP VIEW test_regview") ) + def test_has_type(self, connection): + insp = inspect(connection) + is_true(insp.has_type("test_mview")) + is_true(insp.has_type("test_regview")) + is_true(insp.has_type("testtable")) + def test_mview_is_reflected(self, connection): metadata = MetaData() table = Table("test_mview", metadata, autoload_with=connection) @@ -265,49 +294,99 @@ class MaterializedViewReflectionTest( def test_get_view_names(self, inspect_fixture): insp, conn = inspect_fixture - eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) + eq_(set(insp.get_view_names()), set(["test_regview"])) - def test_get_view_names_plain(self, connection): + def test_get_materialized_view_names(self, inspect_fixture): + insp, conn = inspect_fixture + eq_(set(insp.get_materialized_view_names()), set(["test_mview"])) + + def test_get_view_names_reflection_cache_ok(self, connection): insp = inspect(connection) + eq_(set(insp.get_view_names()), set(["test_regview"])) eq_( - set(insp.get_view_names(include=("plain",))), set(["test_regview"]) + set(insp.get_materialized_view_names()), + 
set(["test_mview"]), + ) + eq_( + set(insp.get_view_names()).union( + insp.get_materialized_view_names() + ), + set(["test_regview", "test_mview"]), ) - def test_get_view_names_plain_string(self, connection): + def test_get_view_definition(self, connection): insp = inspect(connection) - eq_(set(insp.get_view_names(include="plain")), set(["test_regview"])) - def test_get_view_names_materialized(self, connection): - insp = inspect(connection) + def normalize(definition): + return re.sub(r"[\n\t ]+", " ", definition.strip()) + eq_( - set(insp.get_view_names(include=("materialized",))), - set(["test_mview"]), + normalize(insp.get_view_definition("test_mview")), + "SELECT testtable.id, testtable.data FROM testtable;", + ) + eq_( + normalize(insp.get_view_definition("test_regview")), + "SELECT testtable.data FROM testtable;", ) - def test_get_view_names_reflection_cache_ok(self, connection): + def test_get_view_comment(self, connection): insp = inspect(connection) eq_( - set(insp.get_view_names(include=("plain",))), set(["test_regview"]) + insp.get_table_comment("test_regview"), + {"text": "regular view comment"}, ) eq_( - set(insp.get_view_names(include=("materialized",))), - set(["test_mview"]), + insp.get_table_comment("test_mview"), + {"text": "materialized view comment"}, ) - eq_(set(insp.get_view_names()), set(["test_regview", "test_mview"])) - def test_get_view_names_empty(self, connection): + def test_get_multi_view_comment(self, connection): insp = inspect(connection) - assert_raises(ValueError, insp.get_view_names, include=()) + eq_( + insp.get_multi_table_comment(), + {(None, "testtable"): {"text": None}}, + ) + plain = {(None, "test_regview"): {"text": "regular view comment"}} + mat = {(None, "test_mview"): {"text": "materialized view comment"}} + eq_(insp.get_multi_table_comment(kind=ObjectKind.VIEW), plain) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.MATERIALIZED_VIEW), + mat, + ) + eq_( + insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW), + {**plain, **mat}, + ) + eq_( + insp.get_multi_table_comment( + kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY + ), + {}, + ) - def test_get_view_definition(self, connection): + def test_get_multi_view_indexes(self, connection): insp = inspect(connection) + eq_(insp.get_multi_indexes(), {(None, "testtable"): []}) + + exp = { + "name": "mat_index", + "unique": False, + "column_names": ["data"], + "column_sorting": {"data": ("desc",)}, + } + if connection.dialect.server_version_info >= (11, 0): + exp["include_columns"] = [] + exp["dialect_options"] = {"postgresql_include": []} + plain = {(None, "test_regview"): []} + mat = {(None, "test_mview"): [exp]} + eq_(insp.get_multi_indexes(kind=ObjectKind.VIEW), plain) + eq_(insp.get_multi_indexes(kind=ObjectKind.MATERIALIZED_VIEW), mat) + eq_(insp.get_multi_indexes(kind=ObjectKind.ANY_VIEW), {**plain, **mat}) eq_( - re.sub( - r"[\n\t ]+", - " ", - insp.get_view_definition("test_mview").strip(), + insp.get_multi_indexes( + kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY ), - "SELECT testtable.id, testtable.data FROM testtable;", + {}, ) @@ -993,9 +1072,9 @@ class ReflectionTest( go, [ "Skipped unsupported reflection of " - "expression-based index idx1", + "expression-based index idx1 of table party", "Skipped unsupported reflection of " - "expression-based index idx3", + "expression-based index idx3 of table party", ], ) @@ -1016,7 +1095,7 @@ class ReflectionTest( metadata.create_all(connection) - ind = connection.dialect.get_indexes(connection, t1, None) + ind = 
connection.dialect.get_indexes(connection, t1.name, None) partial_definitions = [] for ix in ind: @@ -1337,6 +1416,9 @@ class ReflectionTest( } ], ) + is_true(inspector.has_type("mood", "test_schema")) + is_true(inspector.has_type("mood", "*")) + is_false(inspector.has_type("mood")) def test_inspect_enums(self, metadata, inspect_fixture): @@ -1345,30 +1427,49 @@ class ReflectionTest( enum_type = postgresql.ENUM( "cat", "dog", "rat", name="pet", metadata=metadata ) + enum_type.create(conn) + conn.commit() - with conn.begin(): - enum_type.create(conn) - - eq_( - inspector.get_enums(), - [ - { - "visible": True, - "labels": ["cat", "dog", "rat"], - "name": "pet", - "schema": "public", - } - ], - ) - - def test_get_table_oid(self, metadata, inspect_fixture): - - inspector, conn = inspect_fixture + res = [ + { + "visible": True, + "labels": ["cat", "dog", "rat"], + "name": "pet", + "schema": "public", + } + ] + eq_(inspector.get_enums(), res) + is_true(inspector.has_type("pet", "*")) + is_true(inspector.has_type("pet")) + is_false(inspector.has_type("pet", "test_schema")) + + enum_type.drop(conn) + conn.commit() + eq_(inspector.get_enums(), res) + is_true(inspector.has_type("pet")) + inspector.clear_cache() + eq_(inspector.get_enums(), []) + is_false(inspector.has_type("pet")) + + def test_get_table_oid(self, metadata, connection): + Table("t1", metadata, Column("col", Integer)) + Table("t1", metadata, Column("col", Integer), schema="test_schema") + metadata.create_all(connection) + insp = inspect(connection) + oid = insp.get_table_oid("t1") + oid_schema = insp.get_table_oid("t1", schema="test_schema") + is_true(isinstance(oid, int)) + is_true(isinstance(oid_schema, int)) + is_true(oid != oid_schema) - with conn.begin(): - Table("some_table", metadata, Column("q", Integer)).create(conn) + with expect_raises(exc.NoSuchTableError): + insp.get_table_oid("does_not_exist") - assert inspector.get_table_oid("some_table") is not None + metadata.tables["t1"].drop(connection) + eq_(insp.get_table_oid("t1"), oid) + insp.clear_cache() + with expect_raises(exc.NoSuchTableError): + insp.get_table_oid("t1") def test_inspect_enums_case_sensitive(self, metadata, connection): sa.event.listen( @@ -1707,77 +1808,146 @@ class ReflectionTest( ) def test_reflect_check_warning(self): - rows = [("some name", "NOTCHECK foobar")] + rows = [("foo", "some name", "NOTCHECK foobar")] conn = mock.Mock( execute=lambda *arg, **kw: mock.MagicMock( fetchall=lambda: rows, __iter__=lambda self: iter(rows) ) ) - with mock.patch.object( - testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1 + with testing.expect_warnings( + "Could not parse CHECK constraint text: 'NOTCHECK foobar'" ): - with testing.expect_warnings( - "Could not parse CHECK constraint text: 'NOTCHECK foobar'" - ): - testing.db.dialect.get_check_constraints(conn, "foo") + testing.db.dialect.get_check_constraints(conn, "foo") def test_reflect_extra_newlines(self): rows = [ - ("some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"), - ("some other name", "CHECK ((b\nIS\nNOT\nNULL))"), - ("some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"), - ("some name", "CHECK (c != 'hi\nim a name\n')"), + ("foo", "some name", "CHECK (\n(a \nIS\n NOT\n\n NULL\n)\n)"), + ("foo", "some other name", "CHECK ((b\nIS\nNOT\nNULL))"), + ("foo", "some CRLF name", "CHECK ((c\r\n\r\nIS\r\nNOT\r\nNULL))"), + ("foo", "some name", "CHECK (c != 'hi\nim a name\n')"), ] conn = mock.Mock( execute=lambda *arg, **kw: mock.MagicMock( fetchall=lambda: rows, __iter__=lambda self: iter(rows) ) ) - 
with mock.patch.object( - testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1 - ): - check_constraints = testing.db.dialect.get_check_constraints( - conn, "foo" - ) - eq_( - check_constraints, - [ - { - "name": "some name", - "sqltext": "a \nIS\n NOT\n\n NULL\n", - }, - {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"}, - { - "name": "some CRLF name", - "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL", - }, - {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"}, - ], - ) + check_constraints = testing.db.dialect.get_check_constraints( + conn, "foo" + ) + eq_( + check_constraints, + [ + { + "name": "some name", + "sqltext": "a \nIS\n NOT\n\n NULL\n", + }, + {"name": "some other name", "sqltext": "b\nIS\nNOT\nNULL"}, + { + "name": "some CRLF name", + "sqltext": "c\r\n\r\nIS\r\nNOT\r\nNULL", + }, + {"name": "some name", "sqltext": "c != 'hi\nim a name\n'"}, + ], + ) def test_reflect_with_not_valid_check_constraint(self): - rows = [("some name", "CHECK ((a IS NOT NULL)) NOT VALID")] + rows = [("foo", "some name", "CHECK ((a IS NOT NULL)) NOT VALID")] conn = mock.Mock( execute=lambda *arg, **kw: mock.MagicMock( fetchall=lambda: rows, __iter__=lambda self: iter(rows) ) ) - with mock.patch.object( - testing.db.dialect, "get_table_oid", lambda *arg, **kw: 1 - ): - check_constraints = testing.db.dialect.get_check_constraints( - conn, "foo" + check_constraints = testing.db.dialect.get_check_constraints( + conn, "foo" + ) + eq_( + check_constraints, + [ + { + "name": "some name", + "sqltext": "a IS NOT NULL", + "dialect_options": {"not_valid": True}, + } + ], + ) + + def _apply_stm(self, connection, use_map): + if use_map: + return connection.execution_options( + schema_translate_map={ + None: "foo", + testing.config.test_schema: "bar", + } ) - eq_( - check_constraints, - [ - { - "name": "some name", - "sqltext": "a IS NOT NULL", - "dialect_options": {"not_valid": True}, - } - ], + else: + return connection + + @testing.combinations(True, False, argnames="use_map") + @testing.combinations(True, False, argnames="schema") + def test_schema_translate_map(self, metadata, connection, use_map, schema): + schema = testing.config.test_schema if schema else None + Table( + "foo", + metadata, + Column("id", Integer, primary_key=True), + Column("a", Integer, index=True), + Column( + "b", + ForeignKey(f"{schema}.foo.id" if schema else "foo.id"), + unique=True, + ), + CheckConstraint("a>10", name="foo_check"), + comment="comm", + schema=schema, + ) + metadata.create_all(connection) + if use_map: + connection = connection.execution_options( + schema_translate_map={ + None: "foo", + testing.config.test_schema: "bar", + } ) + insp = inspect(connection) + eq_( + [c["name"] for c in insp.get_columns("foo", schema=schema)], + ["id", "a", "b"], + ) + eq_( + [ + i["column_names"] + for i in insp.get_indexes("foo", schema=schema) + ], + [["b"], ["a"]], + ) + eq_( + insp.get_pk_constraint("foo", schema=schema)[ + "constrained_columns" + ], + ["id"], + ) + eq_(insp.get_table_comment("foo", schema=schema), {"text": "comm"}) + eq_( + [ + f["constrained_columns"] + for f in insp.get_foreign_keys("foo", schema=schema) + ], + [["b"]], + ) + eq_( + [ + c["name"] + for c in insp.get_check_constraints("foo", schema=schema) + ], + ["foo_check"], + ) + eq_( + [ + u["column_names"] + for u in insp.get_unique_constraints("foo", schema=schema) + ], + [["b"]], + ) class CustomTypeReflectionTest(fixtures.TestBase): @@ -1804,9 +1974,23 @@ class CustomTypeReflectionTest(fixtures.TestBase): ("my_custom_type(ARG1)", ("ARG1", 
None)), ("my_custom_type(ARG1, ARG2)", ("ARG1", "ARG2")), ]: - column_info = dialect._get_column_info( - "colname", sch, None, False, {}, {}, "public", None, "", None + row_dict = { + "name": "colname", + "table_name": "tblname", + "format_type": sch, + "default": None, + "not_null": False, + "comment": None, + "generated": "", + "identity_options": None, + } + column_info = dialect._get_columns_info( + [row_dict], {}, {}, "public" ) + assert ("public", "tblname") in column_info + column_info = column_info[("public", "tblname")] + assert len(column_info) == 1 + column_info = column_info[0] assert isinstance(column_info["type"], self.CustomType) eq_(column_info["type"].arg1, args[0]) eq_(column_info["type"].arg2, args[1]) @@ -1951,3 +2135,64 @@ class IdentityReflectionTest(fixtures.TablesTest): exp = default.copy() exp.update(maxvalue=2**15 - 1) eq_(col["identity"], exp) + + +class TestReflectDifficultColTypes(fixtures.TablesTest): + __only_on__ = "postgresql" + __backend__ = True + + def define_tables(metadata): + Table( + "sample_table", + metadata, + Column("c1", Integer, primary_key=True), + Column("c2", Integer, unique=True), + Column("c3", Integer), + Index("sample_table_index", "c2", "c3"), + ) + + def check_int_list(self, row, key): + value = row[key] + is_true(isinstance(value, list)) + is_true(len(value) > 0) + is_true(all(isinstance(v, int) for v in value)) + + def test_pg_index(self, connection): + insp = inspect(connection) + + pgc_oid = insp.get_table_oid("sample_table") + cols = [ + col + for col in pg_catalog.pg_index.c + if testing.db.dialect.server_version_info + >= col.info.get("server_version", (0,)) + ] + + stmt = sa.select(*cols).filter_by(indrelid=pgc_oid) + rows = connection.execute(stmt).mappings().all() + is_true(len(rows) > 0) + cols = [ + col + for col in ["indkey", "indoption", "indclass", "indcollation"] + if testing.db.dialect.server_version_info + >= pg_catalog.pg_index.c[col].info.get("server_version", (0,)) + ] + for row in rows: + for col in cols: + self.check_int_list(row, col) + + def test_pg_constraint(self, connection): + insp = inspect(connection) + + pgc_oid = insp.get_table_oid("sample_table") + cols = [ + col + for col in pg_catalog.pg_constraint.c + if testing.db.dialect.server_version_info + >= col.info.get("server_version", (0,)) + ] + stmt = sa.select(*cols).filter_by(conrelid=pgc_oid) + rows = connection.execute(stmt).mappings().all() + is_true(len(rows) > 0) + for row in rows: + self.check_int_list(row, "conkey") diff --git a/test/dialect/postgresql/test_types.py b/test/dialect/postgresql/test_types.py index 266263d5f..8b6532ce5 100644 --- a/test/dialect/postgresql/test_types.py +++ b/test/dialect/postgresql/test_types.py @@ -451,11 +451,16 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): asserter.assert_( # check for table RegexSQL( - "select relname from pg_class c join pg_namespace.*", + "SELECT pg_catalog.pg_class.relname FROM pg_catalog." 
+ "pg_class JOIN pg_catalog.pg_namespace.*", dialect="postgresql", ), # check for enum, just once - RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"), + RegexSQL( + r"SELECT pg_catalog.pg_type.typname .* WHERE " + "pg_catalog.pg_type.typname = ", + dialect="postgresql", + ), RegexSQL("CREATE TYPE myenum AS ENUM .*", dialect="postgresql"), RegexSQL(r"CREATE TABLE t .*", dialect="postgresql"), ) @@ -465,11 +470,16 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): asserter.assert_( RegexSQL( - "select relname from pg_class c join pg_namespace.*", + "SELECT pg_catalog.pg_class.relname FROM pg_catalog." + "pg_class JOIN pg_catalog.pg_namespace.*", dialect="postgresql", ), RegexSQL("DROP TABLE t", dialect="postgresql"), - RegexSQL(r".*SELECT EXISTS ", dialect="postgresql"), + RegexSQL( + r"SELECT pg_catalog.pg_type.typname .* WHERE " + "pg_catalog.pg_type.typname = ", + dialect="postgresql", + ), RegexSQL("DROP TYPE myenum", dialect="postgresql"), ) @@ -691,23 +701,6 @@ class EnumTest(fixtures.TestBase, AssertsExecutionResults): connection, "fourfivesixtype" ) - def test_no_support(self, testing_engine): - def server_version_info(self): - return (8, 2) - - e = testing_engine() - dialect = e.dialect - dialect._get_server_version_info = server_version_info - - assert dialect.supports_native_enum - e.connect() - assert not dialect.supports_native_enum - - # initialize is called again on new pool - e.dispose() - e.connect() - assert not dialect.supports_native_enum - def test_reflection(self, metadata, connection): etype = Enum( "four", "five", "six", name="fourfivesixtype", metadata=metadata diff --git a/test/dialect/test_sqlite.py b/test/dialect/test_sqlite.py index fb4331998..ed9d67612 100644 --- a/test/dialect/test_sqlite.py +++ b/test/dialect/test_sqlite.py @@ -50,6 +50,7 @@ from sqlalchemy.testing import combinations from sqlalchemy.testing import config from sqlalchemy.testing import engines from sqlalchemy.testing import eq_ +from sqlalchemy.testing import expect_raises from sqlalchemy.testing import expect_warnings from sqlalchemy.testing import fixtures from sqlalchemy.testing import is_ @@ -856,27 +857,15 @@ class AttachedDBTest(fixtures.TestBase): ["foo", "bar"], ) - eq_( - [ - d["name"] - for d in insp.get_columns("nonexistent", schema="test_schema") - ], - [], - ) - eq_( - [ - d["name"] - for d in insp.get_columns("another_created", schema=None) - ], - [], - ) - eq_( - [ - d["name"] - for d in insp.get_columns("local_only", schema="test_schema") - ], - [], - ) + with expect_raises(exc.NoSuchTableError): + insp.get_columns("nonexistent", schema="test_schema") + + with expect_raises(exc.NoSuchTableError): + insp.get_columns("another_created", schema=None) + + with expect_raises(exc.NoSuchTableError): + insp.get_columns("local_only", schema="test_schema") + eq_([d["name"] for d in insp.get_columns("local_only")], ["q", "p"]) def test_table_names_present(self): diff --git a/test/engine/test_reflection.py b/test/engine/test_reflection.py index 76099e863..2f6c06ace 100644 --- a/test/engine/test_reflection.py +++ b/test/engine/test_reflection.py @@ -2,6 +2,7 @@ import unicodedata import sqlalchemy as sa from sqlalchemy import Computed +from sqlalchemy import Connection from sqlalchemy import DefaultClause from sqlalchemy import event from sqlalchemy import FetchedValue @@ -17,6 +18,7 @@ from sqlalchemy import sql from sqlalchemy import String from sqlalchemy import testing from sqlalchemy import UniqueConstraint +from sqlalchemy.engine import Inspector from sqlalchemy.testing 
import assert_raises from sqlalchemy.testing import assert_raises_message from sqlalchemy.testing import AssertsCompiledSQL @@ -1254,12 +1256,13 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): m2 = MetaData() t2 = Table("x", m2, autoload_with=connection) - ck = [ + cks = [ const for const in t2.constraints if isinstance(const, sa.CheckConstraint) - ][0] - + ] + eq_(len(cks), 1) + ck = cks[0] eq_regex(ck.sqltext.text, r"[\(`]*q[\)`]* > 10") eq_(ck.name, "ck1") @@ -1268,11 +1271,17 @@ class ReflectionTest(fixtures.TestBase, ComparesTables): sa.Index("x_ix", t.c.a, t.c.b) metadata.create_all(connection) - def mock_get_columns(self, connection, table_name, **kw): - return [{"name": "b", "type": Integer, "primary_key": False}] + gri = Inspector._get_reflection_info + + def mock_gri(self, *a, **kw): + res = gri(self, *a, **kw) + res.columns[(None, "x")] = [ + col for col in res.columns[(None, "x")] if col["name"] == "b" + ] + return res with testing.mock.patch.object( - connection.dialect, "get_columns", mock_get_columns + Inspector, "_get_reflection_info", mock_gri ): m = MetaData() with testing.expect_warnings( @@ -1409,38 +1418,49 @@ class CreateDropTest(fixtures.TablesTest): eq_(ua, ["users", "email_addresses"]) eq_(oi, ["orders", "items"]) - def test_checkfirst(self, connection): + def test_checkfirst(self, connection: Connection) -> None: insp = inspect(connection) + users = self.tables.users is_false(insp.has_table("users")) users.create(connection) + insp.clear_cache() is_true(insp.has_table("users")) users.create(connection, checkfirst=True) users.drop(connection) users.drop(connection, checkfirst=True) + insp.clear_cache() is_false(insp.has_table("users")) users.create(connection, checkfirst=True) users.drop(connection) - def test_createdrop(self, connection): + def test_createdrop(self, connection: Connection) -> None: insp = inspect(connection) metadata = self.tables_test_metadata + assert metadata is not None metadata.create_all(connection) is_true(insp.has_table("items")) is_true(insp.has_table("email_addresses")) metadata.create_all(connection) + insp.clear_cache() is_true(insp.has_table("items")) metadata.drop_all(connection) + insp.clear_cache() is_false(insp.has_table("items")) is_false(insp.has_table("email_addresses")) metadata.drop_all(connection) + insp.clear_cache() is_false(insp.has_table("items")) - def test_tablenames(self, connection): + def test_has_table_and_table_names(self, connection): + """establish that has_table and get_table_names are consistent w/ + each other with regard to caching + + """ metadata = self.tables_test_metadata metadata.create_all(bind=connection) insp = inspect(connection) @@ -1448,6 +1468,19 @@ class CreateDropTest(fixtures.TablesTest): # ensure all tables we created are in the list. 
is_true(set(insp.get_table_names()).issuperset(metadata.tables)) + assert insp.has_table("items") + assert "items" in insp.get_table_names() + + self.tables.items.drop(connection) + + # cached + assert insp.has_table("items") + assert "items" in insp.get_table_names() + + insp = inspect(connection) + assert not insp.has_table("items") + assert "items" not in insp.get_table_names() + class SchemaManipulationTest(fixtures.TestBase): __backend__ = True @@ -1602,13 +1635,7 @@ class SchemaTest(fixtures.TestBase): __backend__ = True @testing.requires.schemas - @testing.requires.cross_schema_fk_reflection def test_has_schema(self): - if not hasattr(testing.db.dialect, "has_schema"): - testing.config.skip_test( - "dialect %s doesn't have a has_schema method" - % testing.db.dialect.name - ) with testing.db.connect() as conn: eq_( testing.db.dialect.has_schema( diff --git a/test/perf/many_table_reflection.py b/test/perf/many_table_reflection.py new file mode 100644 index 000000000..8749df5c2 --- /dev/null +++ b/test/perf/many_table_reflection.py @@ -0,0 +1,617 @@ +from argparse import ArgumentDefaultsHelpFormatter +from argparse import ArgumentParser +from collections import defaultdict +from contextlib import contextmanager +from functools import wraps +from pprint import pprint +import random +import time + +import sqlalchemy as sa +from sqlalchemy.engine import Inspector + +types = (sa.Integer, sa.BigInteger, sa.String(200), sa.DateTime) +USE_CONNECTION = False + + +def generate_table(meta: sa.MetaData, min_cols, max_cols, dialect_name): + col_number = random.randint(min_cols, max_cols) + table_num = len(meta.tables) + add_identity = random.random() > 0.90 + identity = sa.Identity( + always=random.randint(0, 1), + start=random.randint(1, 100), + increment=random.randint(1, 7), + ) + is_mssql = dialect_name == "mssql" + cols = [] + for i in range(col_number - (0 if is_mssql else add_identity)): + args = [] + if random.random() < 0.95 or table_num == 0: + if is_mssql and add_identity and i == 0: + args.append(sa.Integer) + args.append(identity) + else: + args.append(random.choice(types)) + else: + args.append( + sa.ForeignKey(f"table_{table_num-1}.table_{table_num-1}_col_1") + ) + cols.append( + sa.Column( + f"table_{table_num}_col_{i+1}", + *args, + primary_key=i == 0, + comment=f"primary key of table_{table_num}" + if i == 0 + else None, + index=random.random() > 0.9 and i > 0, + unique=random.random() > 0.95 and i > 0, + ) + ) + if add_identity and not is_mssql: + cols.append( + sa.Column( + f"table_{table_num}_col_{col_number}", + sa.Integer, + identity, + ) + ) + args = () + if table_num % 3 == 0: + # mysql can't do check constraint on PK col + args = (sa.CheckConstraint(cols[1].is_not(None)),) + return sa.Table( + f"table_{table_num}", + meta, + *cols, + *args, + comment=f"comment for table_{table_num}" if table_num % 2 else None, + ) + + +def generate_meta(schema_name, table_number, min_cols, max_cols, dialect_name): + meta = sa.MetaData(schema=schema_name) + log = defaultdict(int) + for _ in range(table_number): + t = generate_table(meta, min_cols, max_cols, dialect_name) + log["tables"] += 1 + log["columns"] += len(t.columns) + log["index"] += len(t.indexes) + log["check_con"] += len( + [c for c in t.constraints if isinstance(c, sa.CheckConstraint)] + ) + log["foreign_keys_con"] += len( + [ + c + for c in t.constraints + if isinstance(c, sa.ForeignKeyConstraint) + ] + ) + log["unique_con"] += len( + [c for c in t.constraints if isinstance(c, sa.UniqueConstraint)] + ) + log["identity"] += 
len([c for c in t.columns if c.identity]) + + print("Meta info", dict(log)) + return meta + + +def log(fn): + @wraps(fn) + def wrap(*a, **kw): + print("Running ", fn.__name__, "...", flush=True, end="") + try: + r = fn(*a, **kw) + except NotImplementedError: + print(" [not implemented]", flush=True) + r = None + else: + print("... done", flush=True) + return r + + return wrap + + +tests = {} + + +def define_test(fn): + name: str = fn.__name__ + if name.startswith("reflect_"): + name = name[8:] + tests[name] = wfn = log(fn) + return wfn + + +@log +def create_tables(engine, meta): + tables = list(meta.tables.values()) + for i in range(0, len(tables), 500): + meta.create_all(engine, tables[i : i + 500]) + + +@log +def drop_tables(engine, meta, schema_name, table_names: list): + tables = list(meta.tables.values())[::-1] + for i in range(0, len(tables), 500): + meta.drop_all(engine, tables[i : i + 500]) + + remaining = sa.inspect(engine).get_table_names(schema=schema_name) + suffix = "" + if engine.dialect.name.startswith("postgres"): + suffix = "CASCADE" + + remaining = sorted( + remaining, key=lambda tn: int(tn.partition("_")[2]), reverse=True + ) + with engine.connect() as conn: + for i, tn in enumerate(remaining): + if engine.dialect.requires_name_normalize: + name = engine.dialect.denormalize_name(tn) + else: + name = tn + if schema_name: + conn.execute( + sa.schema.DDL( + f'DROP TABLE {schema_name}."{name}" {suffix}' + ) + ) + else: + conn.execute(sa.schema.DDL(f'DROP TABLE "{name}" {suffix}')) + if i % 500 == 0: + conn.commit() + conn.commit() + + +@log +def reflect_tables(engine, schema_name): + ref_meta = sa.MetaData(schema=schema_name) + ref_meta.reflect(engine) + + +def verify_dict(multi, single, str_compare=False): + if single is None or multi is None: + return + if single != multi: + keys = set(single) | set(multi) + diff = [] + for key in sorted(keys): + se, me = single.get(key), multi.get(key) + if str(se) != str(me) if str_compare else se != me: + diff.append((key, single.get(key), multi.get(key))) + if diff: + print("\nfound different result:") + pprint(diff) + + +def _single_test( + singe_fn_name, + multi_fn_name, + engine, + schema_name, + table_names, + timing, + mode, +): + single = None + if "single" in mode: + singe_fn = getattr(Inspector, singe_fn_name) + + def go(bind): + insp = sa.inspect(bind) + single = {} + with timing(singe_fn.__name__): + for t in table_names: + single[(schema_name, t)] = singe_fn( + insp, t, schema=schema_name + ) + return single + + if USE_CONNECTION: + with engine.connect() as c: + single = go(c) + else: + single = go(engine) + + multi = None + if "multi" in mode: + insp = sa.inspect(engine) + multi_fn = getattr(Inspector, multi_fn_name) + with timing(multi_fn.__name__): + multi = multi_fn(insp, schema=schema_name) + return (multi, single) + + +@define_test +def reflect_columns( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_columns", + "get_multi_columns", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single, str_compare=True) + + +@define_test +def reflect_table_options( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_table_options", + "get_multi_table_options", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_pk(engine, schema_name, table_names, timing, mode, ignore_diff): + 
multi, single = _single_test( + "get_pk_constraint", + "get_multi_pk_constraint", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_comment( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_table_comment", + "get_multi_table_comment", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_whole_tables( + engine, schema_name, table_names, timing, mode, ignore_diff +): + single = None + meta = sa.MetaData(schema=schema_name) + + if "single" in mode: + + def go(bind): + single = {} + with timing("Table_autoload_with"): + for name in table_names: + single[(None, name)] = sa.Table( + name, meta, autoload_with=bind + ) + return single + + if USE_CONNECTION: + with engine.connect() as c: + single = go(c) + else: + single = go(engine) + + multi_meta = sa.MetaData(schema=schema_name) + if "multi" in mode: + with timing("MetaData_reflect"): + multi_meta.reflect(engine, only=table_names) + return (multi_meta, single) + + +@define_test +def reflect_check_constraints( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_check_constraints", + "get_multi_check_constraints", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_indexes( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_indexes", + "get_multi_indexes", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_foreign_keys( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_foreign_keys", + "get_multi_foreign_keys", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +@define_test +def reflect_unique_constraints( + engine, schema_name, table_names, timing, mode, ignore_diff +): + multi, single = _single_test( + "get_unique_constraints", + "get_multi_unique_constraints", + engine, + schema_name, + table_names, + timing, + mode, + ) + if not ignore_diff: + verify_dict(multi, single) + + +def _apply_events(engine): + queries = defaultdict(list) + + now = 0 + + @sa.event.listens_for(engine, "before_cursor_execute") + def before_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + + nonlocal now + now = time.time() + + @sa.event.listens_for(engine, "after_cursor_execute") + def after_cursor_execute( + conn, cursor, statement, parameters, context, executemany + ): + total = time.time() - now + + if context and context.compiled: + statement_str = context.compiled.string + else: + statement_str = statement + queries[statement_str].append(total) + + return queries + + +def _print_query_stats(queries): + number_of_queries = sum( + len(query_times) for query_times in queries.values() + ) + print("-" * 50) + q_list = list(queries.items()) + q_list.sort(key=lambda rec: -sum(rec[1])) + total = sum([sum(t) for _, t in q_list]) + print(f"total number of queries: {number_of_queries}. 
Total time {total}") + print("-" * 50) + + for stmt, times in q_list: + total_t = sum(times) + max_t = max(times) + min_t = min(times) + avg_t = total_t / len(times) + times.sort() + median_t = times[len(times) // 2] + + print( + f"Query times: {total_t=}, {max_t=}, {min_t=}, {avg_t=}, " + f"{median_t=} Number of calls: {len(times)}" + ) + print(stmt.strip(), "\n") + + +def main(db, schema_name, table_number, min_cols, max_cols, args): + timing = timer() + if args.pool_class: + engine = sa.create_engine( + db, echo=args.echo, poolclass=getattr(sa.pool, args.pool_class) + ) + else: + engine = sa.create_engine(db, echo=args.echo) + + if engine.name == "oracle": + # clear out oracle caches so that we get the real-world time the + # queries would normally take for scripts that aren't run repeatedly + with engine.connect() as conn: + # https://stackoverflow.com/questions/2147456/how-to-clear-all-cached-items-in-oracle + conn.exec_driver_sql("alter system flush buffer_cache") + conn.exec_driver_sql("alter system flush shared_pool") + if not args.no_create: + print( + f"Generating {table_number} using engine {engine} in " + f"schema {schema_name or 'default'}", + ) + meta = sa.MetaData() + table_names = [] + stats = {} + try: + if not args.no_create: + with timing("populate-meta"): + meta = generate_meta( + schema_name, table_number, min_cols, max_cols, engine.name + ) + with timing("create-tables"): + create_tables(engine, meta) + + with timing("get_table_names"): + with engine.connect() as conn: + table_names = engine.dialect.get_table_names( + conn, schema=schema_name + ) + print( + f"Reflected table number {len(table_names)} in " + f"schema {schema_name or 'default'}" + ) + mode = {"single", "multi"} + if args.multi_only: + mode.discard("single") + if args.single_only: + mode.discard("multi") + + if args.sqlstats: + print("starting stats for subsequent tests") + stats = _apply_events(engine) + for test_name, test_fn in tests.items(): + if test_name in args.test or "all" in args.test: + test_fn( + engine, + schema_name, + table_names, + timing, + mode, + args.ignore_diff, + ) + + if args.reflect: + with timing("reflect-tables"): + reflect_tables(engine, schema_name) + finally: + # copy stats to new dict + if args.sqlstats: + stats = dict(stats) + try: + if not args.no_drop: + with timing("drop-tables"): + drop_tables(engine, meta, schema_name, table_names) + finally: + pprint(timing.timing, sort_dicts=False) + if args.sqlstats: + _print_query_stats(stats) + + +def timer(): + timing = {} + + @contextmanager + def track_time(name): + s = time.time() + yield + timing[name] = time.time() - s + + track_time.timing = timing + return track_time + + +if __name__ == "__main__": + parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter) + parser.add_argument( + "--db", help="Database url", default="sqlite:///many-table.db" + ) + parser.add_argument( + "--schema-name", + help="optional schema name", + type=str, + default=None, + ) + parser.add_argument( + "--table-number", + help="Number of table to generate.", + type=int, + default=250, + ) + parser.add_argument( + "--min-cols", + help="Min number of column per table.", + type=int, + default=15, + ) + parser.add_argument( + "--max-cols", + help="Max number of column per table.", + type=int, + default=250, + ) + parser.add_argument( + "--no-create", help="Do not run create tables", action="store_true" + ) + parser.add_argument( + "--no-drop", help="Do not run drop tables", action="store_true" + ) + parser.add_argument("--reflect", help="Run 
reflect", action="store_true") + parser.add_argument( + "--test", + help="Run these tests. 'all' runs all tests", + nargs="+", + choices=tuple(tests) + ("all", "none"), + default=["all"], + ) + parser.add_argument( + "--sqlstats", + help="count and time individual queries", + action="store_true", + ) + parser.add_argument( + "--multi-only", help="Only run multi table tests", action="store_true" + ) + parser.add_argument( + "--single-only", + help="Only run single table tests", + action="store_true", + ) + parser.add_argument( + "--echo", action="store_true", help="Enable echo on the engine" + ) + parser.add_argument( + "--ignore-diff", + action="store_true", + help="Ignores differences in the single/multi reflections", + ) + parser.add_argument( + "--single-inspect-conn", + action="store_true", + help="Uses inspect on a connection instead of on the engine when " + "using single reflections. Mainly for sqlite.", + ) + parser.add_argument("--pool-class", help="The pool class to use") + + args = parser.parse_args() + min_cols = args.min_cols + max_cols = args.max_cols + USE_CONNECTION = args.single_inspect_conn + assert min_cols <= max_cols and min_cols >= 1 + assert not (args.multi_only and args.single_only) + main( + args.db, args.schema_name, args.table_number, min_cols, max_cols, args + ) diff --git a/test/requirements.py b/test/requirements.py index 2d0876158..6bea3ddc9 100644 --- a/test/requirements.py +++ b/test/requirements.py @@ -76,6 +76,18 @@ class DefaultRequirements(SuiteRequirements): return skip_if(no_support("sqlite", "not supported by database")) @property + def foreign_keys_reflect_as_index(self): + return only_on(["mysql", "mariadb"]) + + @property + def unique_index_reflect_as_unique_constraints(self): + return only_on(["mysql", "mariadb"]) + + @property + def unique_constraints_reflect_as_index(self): + return only_on(["mysql", "mariadb", "oracle", "postgresql", "mssql"]) + + @property def foreign_key_constraint_name_reflection(self): return fails_if( lambda config: against(config, ["mysql", "mariadb"]) @@ -84,6 +96,10 @@ class DefaultRequirements(SuiteRequirements): ) @property + def reflect_indexes_with_ascdesc(self): + return fails_if(["oracle"]) + + @property def table_ddl_if_exists(self): """target platform supports IF NOT EXISTS / IF EXISTS for tables.""" @@ -508,6 +524,12 @@ class DefaultRequirements(SuiteRequirements): return exclusions.open() @property + def schema_create_delete(self): + """target database supports schema create and dropped with + 'CREATE SCHEMA' and 'DROP SCHEMA'""" + return exclusions.skip_if(["sqlite", "oracle"]) + + @property def cross_schema_fk_reflection(self): """target system must support reflection of inter-schema foreign keys""" @@ -547,11 +569,13 @@ class DefaultRequirements(SuiteRequirements): @property def check_constraint_reflection(self): - return fails_on_everything_except( - "postgresql", - "sqlite", - "oracle", - self._mysql_and_check_constraints_exist, + return only_on( + [ + "postgresql", + "sqlite", + "oracle", + self._mysql_and_check_constraints_exist, + ] ) @property @@ -562,7 +586,9 @@ class DefaultRequirements(SuiteRequirements): def temp_table_names(self): """target dialect supports listing of temporary table names""" - return only_on(["sqlite", "oracle"]) + skip_if(self._sqlite_file_db) + return only_on(["sqlite", "oracle", "postgresql"]) + skip_if( + self._sqlite_file_db + ) @property def temporary_views(self): @@ -792,8 +818,7 @@ class DefaultRequirements(SuiteRequirements): @property def views(self): """Target 
database must support VIEWs.""" - - return skip_if("drizzle", "no VIEW support") + return exclusions.open() @property def empty_strings_varchar(self): @@ -1336,14 +1361,28 @@ class DefaultRequirements(SuiteRequirements): ) ) + def _has_oracle_test_dblink(self, key): + def check(config): + assert config.db.dialect.name == "oracle" + name = config.file_config.get("sqla_testing", key) + if not name: + return False + with config.db.connect() as conn: + links = config.db.dialect._list_dblinks(conn) + return config.db.dialect.normalize_name(name) in links + + return only_on(["oracle"]) + only_if( + check, + f"{key} option not specified in config or dblink not found in db", + ) + @property def oracle_test_dblink(self): - return skip_if( - lambda config: not config.file_config.has_option( - "sqla_testing", "oracle_db_link" - ), - "oracle_db_link option not specified in config", - ) + return self._has_oracle_test_dblink("oracle_db_link") + + @property + def oracle_test_dblink2(self): + return self._has_oracle_test_dblink("oracle_db_link2") @property def postgresql_test_dblink(self): @@ -1782,6 +1821,19 @@ class DefaultRequirements(SuiteRequirements): return only_on(["mssql"]) + only_if(check) @property + def reflect_table_options(self): + return only_on(["mysql", "mariadb", "oracle"]) + + @property + def materialized_views(self): + """Target database must support MATERIALIZED VIEWs.""" + return only_on(["postgresql", "oracle"]) + + @property + def materialized_views_reflect_pk(self): + return only_on(["oracle"]) + + @property def uuid_data_type(self): """Return databases that support the UUID datatype.""" return only_on(("postgresql >= 8.3", "mariadb >= 10.7.0")) |
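
Reviewer note (not part of the patch): the test changes above repeatedly drive the new batched reflection entry points — Inspector.get_multi_columns(), get_multi_indexes(), get_multi_table_comment() — together with the ObjectKind / ObjectScope filters imported from sqlalchemy.engine. The following is a minimal usage sketch assembled only from the calls exercised in this diff; the engine URL and table contents are hypothetical.

    import sqlalchemy as sa
    from sqlalchemy.engine import ObjectKind, ObjectScope

    # Hypothetical URL; PostgreSQL and Oracle are the backends given the
    # batched catalog queries in this patch.
    engine = sa.create_engine("postgresql+psycopg2://scott:tiger@localhost/test")
    insp = sa.inspect(engine)

    # Each "multi" method returns a dict keyed by (schema, object_name),
    # reflecting many objects per round trip instead of one query per table.
    all_columns = insp.get_multi_columns(schema=None)
    for (schema, name), cols in all_columns.items():
        print(schema, name, [c["name"] for c in cols])

    # ObjectKind selects which object types are reflected, e.g. comments on
    # plain and materialized views together, as in the Oracle/PostgreSQL
    # tests above.
    view_comments = insp.get_multi_table_comment(kind=ObjectKind.ANY_VIEW)
    mat_indexes = insp.get_multi_indexes(kind=ObjectKind.MATERIALIZED_VIEW)

    # ObjectScope narrows the selection further; with no temporary views
    # present this returns {}, as asserted in the tests.
    temp_view_comments = insp.get_multi_table_comment(
        kind=ObjectKind.ANY_VIEW, scope=ObjectScope.TEMPORARY
    )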

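
Reviewer note (not part of the patch): the CreateDropTest changes above also spell out the revised caching contract — has_table(), get_table_names() and related methods are cached on the Inspector after the first call, so DDL executed afterwards is not visible until Inspector.clear_cache() is called or a new Inspector is created. A minimal sketch of that behaviour, assuming an in-memory SQLite database:

    import sqlalchemy as sa

    # Hypothetical in-memory database purely to illustrate the caching contract.
    engine = sa.create_engine("sqlite://")
    metadata = sa.MetaData()
    users = sa.Table("users", metadata, sa.Column("id", sa.Integer, primary_key=True))

    with engine.connect() as conn:
        insp = sa.inspect(conn)
        print(insp.has_table("users"))            # False, and now cached

        users.create(conn)
        print(insp.has_table("users"))            # still False: cached result

        insp.clear_cache()                        # or build a fresh Inspector: sa.inspect(conn)
        print(insp.has_table("users"))            # True
        print("users" in insp.get_table_names())  # True, consistent with has_table()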