summaryrefslogtreecommitdiff
path: root/lib/sqlalchemy/sql/ddl.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/sqlalchemy/sql/ddl.py')
-rw-r--r--lib/sqlalchemy/sql/ddl.py268
1 files changed, 140 insertions, 128 deletions
diff --git a/lib/sqlalchemy/sql/ddl.py b/lib/sqlalchemy/sql/ddl.py
index 08d1072c7..3c7c674f5 100644
--- a/lib/sqlalchemy/sql/ddl.py
+++ b/lib/sqlalchemy/sql/ddl.py
@@ -13,6 +13,7 @@ to invoke them for a create/drop call.
"""
from __future__ import annotations
+import contextlib
import typing
from typing import Any
from typing import Callable
@@ -538,7 +539,7 @@ class CreateTable(_CreateDropBase):
.. versionadded:: 1.4.0b2
"""
- super(CreateTable, self).__init__(element, if_not_exists=if_not_exists)
+ super().__init__(element, if_not_exists=if_not_exists)
self.columns = [CreateColumn(column) for column in element.columns]
self.include_foreign_key_constraints = include_foreign_key_constraints
@@ -685,7 +686,7 @@ class DropTable(_CreateDropBase):
.. versionadded:: 1.4.0b2
"""
- super(DropTable, self).__init__(element, if_exists=if_exists)
+ super().__init__(element, if_exists=if_exists)
class CreateSequence(_CreateDropBase):
@@ -717,7 +718,7 @@ class CreateIndex(_CreateDropBase):
.. versionadded:: 1.4.0b2
"""
- super(CreateIndex, self).__init__(element, if_not_exists=if_not_exists)
+ super().__init__(element, if_not_exists=if_not_exists)
class DropIndex(_CreateDropBase):
@@ -737,7 +738,7 @@ class DropIndex(_CreateDropBase):
.. versionadded:: 1.4.0b2
"""
- super(DropIndex, self).__init__(element, if_exists=if_exists)
+ super().__init__(element, if_exists=if_exists)
class AddConstraint(_CreateDropBase):
@@ -746,7 +747,7 @@ class AddConstraint(_CreateDropBase):
__visit_name__ = "add_constraint"
def __init__(self, element, *args, **kw):
- super(AddConstraint, self).__init__(element, *args, **kw)
+ super().__init__(element, *args, **kw)
element._create_rule = util.portable_instancemethod(
self._create_rule_disable
)
@@ -759,7 +760,7 @@ class DropConstraint(_CreateDropBase):
def __init__(self, element, cascade=False, **kw):
self.cascade = cascade
- super(DropConstraint, self).__init__(element, **kw)
+ super().__init__(element, **kw)
element._create_rule = util.portable_instancemethod(
self._create_rule_disable
)
@@ -809,12 +810,49 @@ class InvokeDDLBase(SchemaVisitor):
def __init__(self, connection):
self.connection = connection
+ @contextlib.contextmanager
+ def with_ddl_events(self, target, **kw):
+ """helper context manager that will apply appropriate DDL events
+ to a CREATE or DROP operation."""
-class SchemaGenerator(InvokeDDLBase):
+ raise NotImplementedError()
+
+
+class InvokeCreateDDLBase(InvokeDDLBase):
+ @contextlib.contextmanager
+ def with_ddl_events(self, target, **kw):
+ """helper context manager that will apply appropriate DDL events
+ to a CREATE or DROP operation."""
+
+ target.dispatch.before_create(
+ target, self.connection, _ddl_runner=self, **kw
+ )
+ yield
+ target.dispatch.after_create(
+ target, self.connection, _ddl_runner=self, **kw
+ )
+
+
+class InvokeDropDDLBase(InvokeDDLBase):
+ @contextlib.contextmanager
+ def with_ddl_events(self, target, **kw):
+ """helper context manager that will apply appropriate DDL events
+ to a CREATE or DROP operation."""
+
+ target.dispatch.before_drop(
+ target, self.connection, _ddl_runner=self, **kw
+ )
+ yield
+ target.dispatch.after_drop(
+ target, self.connection, _ddl_runner=self, **kw
+ )
+
+
+class SchemaGenerator(InvokeCreateDDLBase):
def __init__(
self, dialect, connection, checkfirst=False, tables=None, **kwargs
):
- super(SchemaGenerator, self).__init__(connection, **kwargs)
+ super().__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
self.preparer = dialect.identifier_preparer
@@ -871,36 +909,26 @@ class SchemaGenerator(InvokeDDLBase):
]
event_collection = [t for (t, fks) in collection if t is not None]
- metadata.dispatch.before_create(
- metadata,
- self.connection,
- tables=event_collection,
- checkfirst=self.checkfirst,
- _ddl_runner=self,
- )
-
- for seq in seq_coll:
- self.traverse_single(seq, create_ok=True)
-
- for table, fkcs in collection:
- if table is not None:
- self.traverse_single(
- table,
- create_ok=True,
- include_foreign_key_constraints=fkcs,
- _is_metadata_operation=True,
- )
- else:
- for fkc in fkcs:
- self.traverse_single(fkc)
- metadata.dispatch.after_create(
+ with self.with_ddl_events(
metadata,
- self.connection,
tables=event_collection,
checkfirst=self.checkfirst,
- _ddl_runner=self,
- )
+ ):
+ for seq in seq_coll:
+ self.traverse_single(seq, create_ok=True)
+
+ for table, fkcs in collection:
+ if table is not None:
+ self.traverse_single(
+ table,
+ create_ok=True,
+ include_foreign_key_constraints=fkcs,
+ _is_metadata_operation=True,
+ )
+ else:
+ for fkc in fkcs:
+ self.traverse_single(fkc)
def visit_table(
self,
@@ -912,75 +940,74 @@ class SchemaGenerator(InvokeDDLBase):
if not create_ok and not self._can_create_table(table):
return
- table.dispatch.before_create(
+ with self.with_ddl_events(
table,
- self.connection,
checkfirst=self.checkfirst,
- _ddl_runner=self,
_is_metadata_operation=_is_metadata_operation,
- )
-
- for column in table.columns:
- if column.default is not None:
- self.traverse_single(column.default)
+ ):
- if not self.dialect.supports_alter:
- # e.g., don't omit any foreign key constraints
- include_foreign_key_constraints = None
+ for column in table.columns:
+ if column.default is not None:
+ self.traverse_single(column.default)
- CreateTable(
- table,
- include_foreign_key_constraints=include_foreign_key_constraints,
- )._invoke_with(self.connection)
+ if not self.dialect.supports_alter:
+ # e.g., don't omit any foreign key constraints
+ include_foreign_key_constraints = None
- if hasattr(table, "indexes"):
- for index in table.indexes:
- self.traverse_single(index, create_ok=True)
+ CreateTable(
+ table,
+ include_foreign_key_constraints=(
+ include_foreign_key_constraints
+ ),
+ )._invoke_with(self.connection)
- if self.dialect.supports_comments and not self.dialect.inline_comments:
- if table.comment is not None:
- SetTableComment(table)._invoke_with(self.connection)
+ if hasattr(table, "indexes"):
+ for index in table.indexes:
+ self.traverse_single(index, create_ok=True)
- for column in table.columns:
- if column.comment is not None:
- SetColumnComment(column)._invoke_with(self.connection)
+ if (
+ self.dialect.supports_comments
+ and not self.dialect.inline_comments
+ ):
+ if table.comment is not None:
+ SetTableComment(table)._invoke_with(self.connection)
- if self.dialect.supports_constraint_comments:
- for constraint in table.constraints:
- if constraint.comment is not None:
- self.connection.execute(
- SetConstraintComment(constraint)
- )
+ for column in table.columns:
+ if column.comment is not None:
+ SetColumnComment(column)._invoke_with(self.connection)
- table.dispatch.after_create(
- table,
- self.connection,
- checkfirst=self.checkfirst,
- _ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation,
- )
+ if self.dialect.supports_constraint_comments:
+ for constraint in table.constraints:
+ if constraint.comment is not None:
+ self.connection.execute(
+ SetConstraintComment(constraint)
+ )
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
return
- AddConstraint(constraint)._invoke_with(self.connection)
+
+ with self.with_ddl_events(constraint):
+ AddConstraint(constraint)._invoke_with(self.connection)
def visit_sequence(self, sequence, create_ok=False):
if not create_ok and not self._can_create_sequence(sequence):
return
- CreateSequence(sequence)._invoke_with(self.connection)
+ with self.with_ddl_events(sequence):
+ CreateSequence(sequence)._invoke_with(self.connection)
def visit_index(self, index, create_ok=False):
if not create_ok and not self._can_create_index(index):
return
- CreateIndex(index)._invoke_with(self.connection)
+ with self.with_ddl_events(index):
+ CreateIndex(index)._invoke_with(self.connection)
-class SchemaDropper(InvokeDDLBase):
+class SchemaDropper(InvokeDropDDLBase):
def __init__(
self, dialect, connection, checkfirst=False, tables=None, **kwargs
):
- super(SchemaDropper, self).__init__(connection, **kwargs)
+ super().__init__(connection, **kwargs)
self.checkfirst = checkfirst
self.tables = tables
self.preparer = dialect.identifier_preparer
@@ -1043,36 +1070,26 @@ class SchemaDropper(InvokeDDLBase):
event_collection = [t for (t, fks) in collection if t is not None]
- metadata.dispatch.before_drop(
+ with self.with_ddl_events(
metadata,
- self.connection,
tables=event_collection,
checkfirst=self.checkfirst,
- _ddl_runner=self,
- )
+ ):
- for table, fkcs in collection:
- if table is not None:
- self.traverse_single(
- table,
- drop_ok=True,
- _is_metadata_operation=True,
- _ignore_sequences=seq_coll,
- )
- else:
- for fkc in fkcs:
- self.traverse_single(fkc)
+ for table, fkcs in collection:
+ if table is not None:
+ self.traverse_single(
+ table,
+ drop_ok=True,
+ _is_metadata_operation=True,
+ _ignore_sequences=seq_coll,
+ )
+ else:
+ for fkc in fkcs:
+ self.traverse_single(fkc)
- for seq in seq_coll:
- self.traverse_single(seq, drop_ok=seq.column is None)
-
- metadata.dispatch.after_drop(
- metadata,
- self.connection,
- tables=event_collection,
- checkfirst=self.checkfirst,
- _ddl_runner=self,
- )
+ for seq in seq_coll:
+ self.traverse_single(seq, drop_ok=seq.column is None)
def _can_drop_table(self, table):
self.dialect.validate_identifier(table.name)
@@ -1110,7 +1127,8 @@ class SchemaDropper(InvokeDDLBase):
if not drop_ok and not self._can_drop_index(index):
return
- DropIndex(index)(index, self.connection)
+ with self.with_ddl_events(index):
+ DropIndex(index)(index, self.connection)
def visit_table(
self,
@@ -1122,46 +1140,40 @@ class SchemaDropper(InvokeDDLBase):
if not drop_ok and not self._can_drop_table(table):
return
- table.dispatch.before_drop(
+ with self.with_ddl_events(
table,
- self.connection,
checkfirst=self.checkfirst,
- _ddl_runner=self,
_is_metadata_operation=_is_metadata_operation,
- )
-
- DropTable(table)._invoke_with(self.connection)
+ ):
- # traverse client side defaults which may refer to server-side
- # sequences. noting that some of these client side defaults may also be
- # set up as server side defaults (see https://docs.sqlalchemy.org/en/
- # latest/core/defaults.html#associating-a-sequence-as-the-server-side-
- # default), so have to be dropped after the table is dropped.
- for column in table.columns:
- if (
- column.default is not None
- and column.default not in _ignore_sequences
- ):
- self.traverse_single(column.default)
+ DropTable(table)._invoke_with(self.connection)
- table.dispatch.after_drop(
- table,
- self.connection,
- checkfirst=self.checkfirst,
- _ddl_runner=self,
- _is_metadata_operation=_is_metadata_operation,
- )
+ # traverse client side defaults which may refer to server-side
+ # sequences. noting that some of these client side defaults may
+ # also be set up as server side defaults
+ # (see https://docs.sqlalchemy.org/en/
+ # latest/core/defaults.html
+ # #associating-a-sequence-as-the-server-side-
+ # default), so have to be dropped after the table is dropped.
+ for column in table.columns:
+ if (
+ column.default is not None
+ and column.default not in _ignore_sequences
+ ):
+ self.traverse_single(column.default)
def visit_foreign_key_constraint(self, constraint):
if not self.dialect.supports_alter:
return
- DropConstraint(constraint)._invoke_with(self.connection)
+ with self.with_ddl_events(constraint):
+ DropConstraint(constraint)._invoke_with(self.connection)
def visit_sequence(self, sequence, drop_ok=False):
if not drop_ok and not self._can_drop_sequence(sequence):
return
- DropSequence(sequence)._invoke_with(self.connection)
+ with self.with_ddl_events(sequence):
+ DropSequence(sequence)._invoke_with(self.connection)
def sort_tables(