summaryrefslogtreecommitdiff
path: root/alembic/autogenerate/compare.py
diff options
context:
space:
mode:
Diffstat (limited to 'alembic/autogenerate/compare.py')
-rw-r--r--alembic/autogenerate/compare.py148
1 files changed, 98 insertions, 50 deletions
diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py
index eb649ac..6577025 100644
--- a/alembic/autogenerate/compare.py
+++ b/alembic/autogenerate/compare.py
@@ -1,4 +1,3 @@
-import collections
from sqlalchemy import schema as sa_schema, types as sqltypes
from sqlalchemy import event
import logging
@@ -7,6 +6,7 @@ from sqlalchemy.util import OrderedSet
import re
from .render import _user_defined_render
import contextlib
+from alembic.ddl.base import _fk_spec
log = logging.getLogger(__name__)
@@ -139,6 +139,25 @@ def _make_unique_constraint(params, conn_table):
)
+def _make_foreign_key(params, conn_table):
+ tname = params['referred_table']
+ if params['referred_schema']:
+ tname = "%s.%s" % (params['referred_schema'], tname)
+
+ const = sa_schema.ForeignKeyConstraint(
+ [conn_table.c[cname] for cname in params['constrained_columns']],
+ ["%s.%s" % (tname, n) for n in params['referred_columns']],
+ onupdate=params.get('onupdate'),
+ ondelete=params.get('ondelete'),
+ deferrable=params.get('deferrable'),
+ initially=params.get('initially'),
+ name=params['name']
+ )
+ # needed by 0.7
+ conn_table.append_constraint(const)
+ return const
+
+
@contextlib.contextmanager
def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
diffs, autogen_context, inspector):
@@ -194,7 +213,6 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table,
log.info("Detected removed column '%s.%s'", name, cname)
-
class _constraint_sig(object):
def __eq__(self, other):
@@ -235,6 +253,20 @@ class _ix_constraint_sig(_constraint_sig):
return _get_index_column_names(self.const)
+class _fk_constraint_sig(_constraint_sig):
+ def __init__(self, const):
+ self.const = const
+ self.name = const.name
+ self.source_schema, self.source_table, \
+ self.source_columns, self.target_schema, self.target_table, \
+ self.target_columns = _fk_spec(const)
+
+ self.sig = (
+ self.source_schema, self.source_table, tuple(self.source_columns),
+ self.target_schema, self.target_table, tuple(self.target_columns)
+ )
+
+
def _get_index_column_names(idx):
if compat.sqla_08:
return [getattr(exp, "name", None) for exp in idx.expressions]
@@ -571,63 +603,79 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col,
cname
)
-FKInfo = collections.namedtuple('fk_info', ['constrained_columns',
- 'referred_table',
- 'referred_columns'])
-
def _compare_foreign_keys(schema, tname, object_filters, conn_table,
metadata_table, diffs, autogen_context, inspector):
- # This methods checks foreign keys that tables contain in models with
- # foreign keys that are in db.
- # Get all necessary information about key of current table from db
+ # if we're doing CREATE TABLE, all FKs are created
+ # inline within the table def
if conn_table is None:
return
- fk_db = {}
- if hasattr(inspector, "get_foreign_keys"):
- try:
- fk_db = dict((_get_fk_info_from_db(i), i['name']) for i in
- inspector.get_foreign_keys(tname, schema=schema))
- except NotImplementedError:
- pass
-
+ metadata_fks = set(
+ fk for fk in metadata_table.constraints
+ if isinstance(fk, sa_schema.ForeignKeyConstraint)
+ )
+ metadata_fks = set(_fk_constraint_sig(fk) for fk in metadata_fks)
- # Get all necessary information about key of current table from
- # models
- fk_models = dict((_get_fk_info_from_model(fk), fk) for fk in
- metadata_table.foreign_keys)
- fk_models_set = set(fk_models.keys())
- fk_db_set = set(fk_db.keys())
- for key in (fk_db_set - fk_models_set):
- diffs.append(('drop_fk', fk_db[key], conn_table, key))
- log.info(("Detected removed foreign key %(fk)r on "
- "table %(table)r"), {'fk': fk_db[key],
- 'table': conn_table})
- for key in (fk_models_set - fk_db_set):
- diffs.append(('add_fk', fk_models[key], key))
- log.info((
- "Detected added foreign key for column %(fk)r on table "
- "%(table)r"), {'fk': fk_models[key].column.name,
- 'table': conn_table})
- return diffs
+ conn_fks = inspector.get_foreign_keys(tname, schema=schema)
+ conn_fks = set(_fk_constraint_sig(_make_foreign_key(const, conn_table))
+ for const in conn_fks)
+ conn_fks_by_sig = dict(
+ (c.sig, c) for c in conn_fks
+ )
+ metadata_fks_by_sig = dict(
+ (c.sig, c) for c in metadata_fks
+ )
-def _get_fk_info_from_db(fk):
- return FKInfo(tuple(fk['constrained_columns']),
- fk['referred_table'],
- tuple(fk['referred_columns']))
+ metadata_fks_by_name = dict(
+ (c.name, c) for c in metadata_fks if c.name is not None
+ )
+ conn_fks_by_name = dict(
+ (c.name, c) for c in conn_fks if c.name is not None
+ )
+ def _add_fk(obj, compare_to):
+ if _run_filters(
+ obj.const, obj.name, "foreignkey", False,
+ compare_to, object_filters):
+ diffs.append(('add_fk', const.const))
+
+ log.info(
+ "Detected added foreign key (%s)(%s) on table %s%s",
+ ", ".join(obj.source_columns),
+ ", ".join(obj.target_columns),
+ "%s." % obj.source_schema if obj.source_schema else "",
+ obj.source_table)
+
+ def _remove_fk(obj, compare_to):
+ if _run_filters(
+ obj.const, obj.name, "foreignkey", True,
+ compare_to, object_filters):
+ diffs.append(('remove_fk', obj.const))
+ log.info(
+ "Detected removed foreign key (%s)(%s) on table %s%s",
+ ", ".join(obj.source_columns),
+ ", ".join(obj.target_columns),
+ "%s." % obj.source_schema if obj.source_schema else "",
+ obj.source_table)
+
+ # so far it appears we don't need to do this by name at all.
+ # SQLite doesn't preserve constraint names anyway
+
+ for removed_sig in set(conn_fks_by_sig).difference(metadata_fks_by_sig):
+ const = conn_fks_by_sig[removed_sig]
+ if removed_sig not in metadata_fks_by_sig:
+ compare_to = metadata_fks_by_name[const.name].const \
+ if const.name in metadata_fks_by_name else None
+ _remove_fk(const, compare_to)
+
+ for added_sig in set(metadata_fks_by_sig).difference(conn_fks_by_sig):
+ const = metadata_fks_by_sig[added_sig]
+ if added_sig not in conn_fks_by_sig:
+ compare_to = conn_fks_by_name[const.name].const \
+ if const.name in conn_fks_by_name else None
+ _add_fk(const, compare_to)
-def _get_fk_info_from_model(fk):
- constrained_columns = []
- for column in fk.constraint.columns:
- if not isinstance(column, basestring):
- constrained_columns.append(column.name)
- else:
- constrained_columns.append(column)
- return FKInfo(
- tuple(constrained_columns),
- fk.column.table.name,
- tuple(k.column.name for k in fk.constraint._elements.values()))
+ return diffs