diff options
Diffstat (limited to 'alembic/autogenerate/compare.py')
-rw-r--r-- | alembic/autogenerate/compare.py | 148 |
1 files changed, 98 insertions, 50 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index eb649ac..6577025 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -1,4 +1,3 @@ -import collections from sqlalchemy import schema as sa_schema, types as sqltypes from sqlalchemy import event import logging @@ -7,6 +6,7 @@ from sqlalchemy.util import OrderedSet import re from .render import _user_defined_render import contextlib +from alembic.ddl.base import _fk_spec log = logging.getLogger(__name__) @@ -139,6 +139,25 @@ def _make_unique_constraint(params, conn_table): ) +def _make_foreign_key(params, conn_table): + tname = params['referred_table'] + if params['referred_schema']: + tname = "%s.%s" % (params['referred_schema'], tname) + + const = sa_schema.ForeignKeyConstraint( + [conn_table.c[cname] for cname in params['constrained_columns']], + ["%s.%s" % (tname, n) for n in params['referred_columns']], + onupdate=params.get('onupdate'), + ondelete=params.get('ondelete'), + deferrable=params.get('deferrable'), + initially=params.get('initially'), + name=params['name'] + ) + # needed by 0.7 + conn_table.append_constraint(const) + return const + + @contextlib.contextmanager def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, diffs, autogen_context, inspector): @@ -194,7 +213,6 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, log.info("Detected removed column '%s.%s'", name, cname) - class _constraint_sig(object): def __eq__(self, other): @@ -235,6 +253,20 @@ class _ix_constraint_sig(_constraint_sig): return _get_index_column_names(self.const) +class _fk_constraint_sig(_constraint_sig): + def __init__(self, const): + self.const = const + self.name = const.name + self.source_schema, self.source_table, \ + self.source_columns, self.target_schema, self.target_table, \ + self.target_columns = _fk_spec(const) + + self.sig = ( + self.source_schema, self.source_table, tuple(self.source_columns), + self.target_schema, self.target_table, tuple(self.target_columns) + ) + + def _get_index_column_names(idx): if compat.sqla_08: return [getattr(exp, "name", None) for exp in idx.expressions] @@ -571,63 +603,79 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col, cname ) -FKInfo = collections.namedtuple('fk_info', ['constrained_columns', - 'referred_table', - 'referred_columns']) - def _compare_foreign_keys(schema, tname, object_filters, conn_table, metadata_table, diffs, autogen_context, inspector): - # This methods checks foreign keys that tables contain in models with - # foreign keys that are in db. - # Get all necessary information about key of current table from db + # if we're doing CREATE TABLE, all FKs are created + # inline within the table def if conn_table is None: return - fk_db = {} - if hasattr(inspector, "get_foreign_keys"): - try: - fk_db = dict((_get_fk_info_from_db(i), i['name']) for i in - inspector.get_foreign_keys(tname, schema=schema)) - except NotImplementedError: - pass - + metadata_fks = set( + fk for fk in metadata_table.constraints + if isinstance(fk, sa_schema.ForeignKeyConstraint) + ) + metadata_fks = set(_fk_constraint_sig(fk) for fk in metadata_fks) - # Get all necessary information about key of current table from - # models - fk_models = dict((_get_fk_info_from_model(fk), fk) for fk in - metadata_table.foreign_keys) - fk_models_set = set(fk_models.keys()) - fk_db_set = set(fk_db.keys()) - for key in (fk_db_set - fk_models_set): - diffs.append(('drop_fk', fk_db[key], conn_table, key)) - log.info(("Detected removed foreign key %(fk)r on " - "table %(table)r"), {'fk': fk_db[key], - 'table': conn_table}) - for key in (fk_models_set - fk_db_set): - diffs.append(('add_fk', fk_models[key], key)) - log.info(( - "Detected added foreign key for column %(fk)r on table " - "%(table)r"), {'fk': fk_models[key].column.name, - 'table': conn_table}) - return diffs + conn_fks = inspector.get_foreign_keys(tname, schema=schema) + conn_fks = set(_fk_constraint_sig(_make_foreign_key(const, conn_table)) + for const in conn_fks) + conn_fks_by_sig = dict( + (c.sig, c) for c in conn_fks + ) + metadata_fks_by_sig = dict( + (c.sig, c) for c in metadata_fks + ) -def _get_fk_info_from_db(fk): - return FKInfo(tuple(fk['constrained_columns']), - fk['referred_table'], - tuple(fk['referred_columns'])) + metadata_fks_by_name = dict( + (c.name, c) for c in metadata_fks if c.name is not None + ) + conn_fks_by_name = dict( + (c.name, c) for c in conn_fks if c.name is not None + ) + def _add_fk(obj, compare_to): + if _run_filters( + obj.const, obj.name, "foreignkey", False, + compare_to, object_filters): + diffs.append(('add_fk', const.const)) + + log.info( + "Detected added foreign key (%s)(%s) on table %s%s", + ", ".join(obj.source_columns), + ", ".join(obj.target_columns), + "%s." % obj.source_schema if obj.source_schema else "", + obj.source_table) + + def _remove_fk(obj, compare_to): + if _run_filters( + obj.const, obj.name, "foreignkey", True, + compare_to, object_filters): + diffs.append(('remove_fk', obj.const)) + log.info( + "Detected removed foreign key (%s)(%s) on table %s%s", + ", ".join(obj.source_columns), + ", ".join(obj.target_columns), + "%s." % obj.source_schema if obj.source_schema else "", + obj.source_table) + + # so far it appears we don't need to do this by name at all. + # SQLite doesn't preserve constraint names anyway + + for removed_sig in set(conn_fks_by_sig).difference(metadata_fks_by_sig): + const = conn_fks_by_sig[removed_sig] + if removed_sig not in metadata_fks_by_sig: + compare_to = metadata_fks_by_name[const.name].const \ + if const.name in metadata_fks_by_name else None + _remove_fk(const, compare_to) + + for added_sig in set(metadata_fks_by_sig).difference(conn_fks_by_sig): + const = metadata_fks_by_sig[added_sig] + if added_sig not in conn_fks_by_sig: + compare_to = conn_fks_by_name[const.name].const \ + if const.name in conn_fks_by_name else None + _add_fk(const, compare_to) -def _get_fk_info_from_model(fk): - constrained_columns = [] - for column in fk.constraint.columns: - if not isinstance(column, basestring): - constrained_columns.append(column.name) - else: - constrained_columns.append(column) - return FKInfo( - tuple(constrained_columns), - fk.column.table.name, - tuple(k.column.name for k in fk.constraint._elements.values())) + return diffs |