diff options
Diffstat (limited to 'lib/sqlalchemy/sql/ddl.py')
| -rw-r--r-- | lib/sqlalchemy/sql/ddl.py | 268 |
1 files changed, 140 insertions, 128 deletions
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py index 08d1072c7..3c7c674f5 100644 --- a/lib/sqlalchemy/sql/ddl.py +++ b/lib/sqlalchemy/sql/ddl.py @@ -13,6 +13,7 @@ to invoke them for a create/drop call. """ from __future__ import annotations +import contextlib import typing from typing import Any from typing import Callable @@ -538,7 +539,7 @@ class CreateTable(_CreateDropBase): .. versionadded:: 1.4.0b2 """ - super(CreateTable, self).__init__(element, if_not_exists=if_not_exists) + super().__init__(element, if_not_exists=if_not_exists) self.columns = [CreateColumn(column) for column in element.columns] self.include_foreign_key_constraints = include_foreign_key_constraints @@ -685,7 +686,7 @@ class DropTable(_CreateDropBase): .. versionadded:: 1.4.0b2 """ - super(DropTable, self).__init__(element, if_exists=if_exists) + super().__init__(element, if_exists=if_exists) class CreateSequence(_CreateDropBase): @@ -717,7 +718,7 @@ class CreateIndex(_CreateDropBase): .. versionadded:: 1.4.0b2 """ - super(CreateIndex, self).__init__(element, if_not_exists=if_not_exists) + super().__init__(element, if_not_exists=if_not_exists) class DropIndex(_CreateDropBase): @@ -737,7 +738,7 @@ class DropIndex(_CreateDropBase): .. versionadded:: 1.4.0b2 """ - super(DropIndex, self).__init__(element, if_exists=if_exists) + super().__init__(element, if_exists=if_exists) class AddConstraint(_CreateDropBase): @@ -746,7 +747,7 @@ class AddConstraint(_CreateDropBase): __visit_name__ = "add_constraint" def __init__(self, element, *args, **kw): - super(AddConstraint, self).__init__(element, *args, **kw) + super().__init__(element, *args, **kw) element._create_rule = util.portable_instancemethod( self._create_rule_disable ) @@ -759,7 +760,7 @@ class DropConstraint(_CreateDropBase): def __init__(self, element, cascade=False, **kw): self.cascade = cascade - super(DropConstraint, self).__init__(element, **kw) + super().__init__(element, **kw) element._create_rule = util.portable_instancemethod( self._create_rule_disable ) @@ -809,12 +810,49 @@ class InvokeDDLBase(SchemaVisitor): def __init__(self, connection): self.connection = connection + @contextlib.contextmanager + def with_ddl_events(self, target, **kw): + """helper context manager that will apply appropriate DDL events + to a CREATE or DROP operation.""" -class SchemaGenerator(InvokeDDLBase): + raise NotImplementedError() + + +class InvokeCreateDDLBase(InvokeDDLBase): + @contextlib.contextmanager + def with_ddl_events(self, target, **kw): + """helper context manager that will apply appropriate DDL events + to a CREATE or DROP operation.""" + + target.dispatch.before_create( + target, self.connection, _ddl_runner=self, **kw + ) + yield + target.dispatch.after_create( + target, self.connection, _ddl_runner=self, **kw + ) + + +class InvokeDropDDLBase(InvokeDDLBase): + @contextlib.contextmanager + def with_ddl_events(self, target, **kw): + """helper context manager that will apply appropriate DDL events + to a CREATE or DROP operation.""" + + target.dispatch.before_drop( + target, self.connection, _ddl_runner=self, **kw + ) + yield + target.dispatch.after_drop( + target, self.connection, _ddl_runner=self, **kw + ) + + +class SchemaGenerator(InvokeCreateDDLBase): def __init__( self, dialect, connection, checkfirst=False, tables=None, **kwargs ): - super(SchemaGenerator, self).__init__(connection, **kwargs) + super().__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables self.preparer = dialect.identifier_preparer @@ -871,36 +909,26 @@ class SchemaGenerator(InvokeDDLBase): ] event_collection = [t for (t, fks) in collection if t is not None] - metadata.dispatch.before_create( - metadata, - self.connection, - tables=event_collection, - checkfirst=self.checkfirst, - _ddl_runner=self, - ) - - for seq in seq_coll: - self.traverse_single(seq, create_ok=True) - - for table, fkcs in collection: - if table is not None: - self.traverse_single( - table, - create_ok=True, - include_foreign_key_constraints=fkcs, - _is_metadata_operation=True, - ) - else: - for fkc in fkcs: - self.traverse_single(fkc) - metadata.dispatch.after_create( + with self.with_ddl_events( metadata, - self.connection, tables=event_collection, checkfirst=self.checkfirst, - _ddl_runner=self, - ) + ): + for seq in seq_coll: + self.traverse_single(seq, create_ok=True) + + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, + create_ok=True, + include_foreign_key_constraints=fkcs, + _is_metadata_operation=True, + ) + else: + for fkc in fkcs: + self.traverse_single(fkc) def visit_table( self, @@ -912,75 +940,74 @@ class SchemaGenerator(InvokeDDLBase): if not create_ok and not self._can_create_table(table): return - table.dispatch.before_create( + with self.with_ddl_events( table, - self.connection, checkfirst=self.checkfirst, - _ddl_runner=self, _is_metadata_operation=_is_metadata_operation, - ) - - for column in table.columns: - if column.default is not None: - self.traverse_single(column.default) + ): - if not self.dialect.supports_alter: - # e.g., don't omit any foreign key constraints - include_foreign_key_constraints = None + for column in table.columns: + if column.default is not None: + self.traverse_single(column.default) - CreateTable( - table, - include_foreign_key_constraints=include_foreign_key_constraints, - )._invoke_with(self.connection) + if not self.dialect.supports_alter: + # e.g., don't omit any foreign key constraints + include_foreign_key_constraints = None - if hasattr(table, "indexes"): - for index in table.indexes: - self.traverse_single(index, create_ok=True) + CreateTable( + table, + include_foreign_key_constraints=( + include_foreign_key_constraints + ), + )._invoke_with(self.connection) - if self.dialect.supports_comments and not self.dialect.inline_comments: - if table.comment is not None: - SetTableComment(table)._invoke_with(self.connection) + if hasattr(table, "indexes"): + for index in table.indexes: + self.traverse_single(index, create_ok=True) - for column in table.columns: - if column.comment is not None: - SetColumnComment(column)._invoke_with(self.connection) + if ( + self.dialect.supports_comments + and not self.dialect.inline_comments + ): + if table.comment is not None: + SetTableComment(table)._invoke_with(self.connection) - if self.dialect.supports_constraint_comments: - for constraint in table.constraints: - if constraint.comment is not None: - self.connection.execute( - SetConstraintComment(constraint) - ) + for column in table.columns: + if column.comment is not None: + SetColumnComment(column)._invoke_with(self.connection) - table.dispatch.after_create( - table, - self.connection, - checkfirst=self.checkfirst, - _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation, - ) + if self.dialect.supports_constraint_comments: + for constraint in table.constraints: + if constraint.comment is not None: + self.connection.execute( + SetConstraintComment(constraint) + ) def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: return - AddConstraint(constraint)._invoke_with(self.connection) + + with self.with_ddl_events(constraint): + AddConstraint(constraint)._invoke_with(self.connection) def visit_sequence(self, sequence, create_ok=False): if not create_ok and not self._can_create_sequence(sequence): return - CreateSequence(sequence)._invoke_with(self.connection) + with self.with_ddl_events(sequence): + CreateSequence(sequence)._invoke_with(self.connection) def visit_index(self, index, create_ok=False): if not create_ok and not self._can_create_index(index): return - CreateIndex(index)._invoke_with(self.connection) + with self.with_ddl_events(index): + CreateIndex(index)._invoke_with(self.connection) -class SchemaDropper(InvokeDDLBase): +class SchemaDropper(InvokeDropDDLBase): def __init__( self, dialect, connection, checkfirst=False, tables=None, **kwargs ): - super(SchemaDropper, self).__init__(connection, **kwargs) + super().__init__(connection, **kwargs) self.checkfirst = checkfirst self.tables = tables self.preparer = dialect.identifier_preparer @@ -1043,36 +1070,26 @@ class SchemaDropper(InvokeDDLBase): event_collection = [t for (t, fks) in collection if t is not None] - metadata.dispatch.before_drop( + with self.with_ddl_events( metadata, - self.connection, tables=event_collection, checkfirst=self.checkfirst, - _ddl_runner=self, - ) + ): - for table, fkcs in collection: - if table is not None: - self.traverse_single( - table, - drop_ok=True, - _is_metadata_operation=True, - _ignore_sequences=seq_coll, - ) - else: - for fkc in fkcs: - self.traverse_single(fkc) + for table, fkcs in collection: + if table is not None: + self.traverse_single( + table, + drop_ok=True, + _is_metadata_operation=True, + _ignore_sequences=seq_coll, + ) + else: + for fkc in fkcs: + self.traverse_single(fkc) - for seq in seq_coll: - self.traverse_single(seq, drop_ok=seq.column is None) - - metadata.dispatch.after_drop( - metadata, - self.connection, - tables=event_collection, - checkfirst=self.checkfirst, - _ddl_runner=self, - ) + for seq in seq_coll: + self.traverse_single(seq, drop_ok=seq.column is None) def _can_drop_table(self, table): self.dialect.validate_identifier(table.name) @@ -1110,7 +1127,8 @@ class SchemaDropper(InvokeDDLBase): if not drop_ok and not self._can_drop_index(index): return - DropIndex(index)(index, self.connection) + with self.with_ddl_events(index): + DropIndex(index)(index, self.connection) def visit_table( self, @@ -1122,46 +1140,40 @@ class SchemaDropper(InvokeDDLBase): if not drop_ok and not self._can_drop_table(table): return - table.dispatch.before_drop( + with self.with_ddl_events( table, - self.connection, checkfirst=self.checkfirst, - _ddl_runner=self, _is_metadata_operation=_is_metadata_operation, - ) - - DropTable(table)._invoke_with(self.connection) + ): - # traverse client side defaults which may refer to server-side - # sequences. noting that some of these client side defaults may also be - # set up as server side defaults (see https://docs.sqlalchemy.org/en/ - # latest/core/defaults.html#associating-a-sequence-as-the-server-side- - # default), so have to be dropped after the table is dropped. - for column in table.columns: - if ( - column.default is not None - and column.default not in _ignore_sequences - ): - self.traverse_single(column.default) + DropTable(table)._invoke_with(self.connection) - table.dispatch.after_drop( - table, - self.connection, - checkfirst=self.checkfirst, - _ddl_runner=self, - _is_metadata_operation=_is_metadata_operation, - ) + # traverse client side defaults which may refer to server-side + # sequences. noting that some of these client side defaults may + # also be set up as server side defaults + # (see https://docs.sqlalchemy.org/en/ + # latest/core/defaults.html + # #associating-a-sequence-as-the-server-side- + # default), so have to be dropped after the table is dropped. + for column in table.columns: + if ( + column.default is not None + and column.default not in _ignore_sequences + ): + self.traverse_single(column.default) def visit_foreign_key_constraint(self, constraint): if not self.dialect.supports_alter: return - DropConstraint(constraint)._invoke_with(self.connection) + with self.with_ddl_events(constraint): + DropConstraint(constraint)._invoke_with(self.connection) def visit_sequence(self, sequence, drop_ok=False): if not drop_ok and not self._can_drop_sequence(sequence): return - DropSequence(sequence)._invoke_with(self.connection) + with self.with_ddl_events(sequence): + DropSequence(sequence)._invoke_with(self.connection) def sort_tables( |
