diff options
44 files changed, 1661 insertions, 1357 deletions
diff --git a/alembic/__init__.py b/alembic/__init__.py index f2d3932..56a254d 100644 --- a/alembic/__init__.py +++ b/alembic/__init__.py @@ -7,5 +7,3 @@ package_dir = path.abspath(path.dirname(__file__)) from . import op from . import context - - diff --git a/alembic/autogenerate/api.py b/alembic/autogenerate/api.py index 148e352..b13a57b 100644 --- a/alembic/autogenerate/api.py +++ b/alembic/autogenerate/api.py @@ -8,13 +8,15 @@ from sqlalchemy.engine.reflection import Inspector from sqlalchemy.util import OrderedSet from .compare import _compare_tables from .render import _drop_table, _drop_column, _drop_index, _drop_constraint, \ - _add_table, _add_column, _add_index, _add_constraint, _modify_col + _add_table, _add_column, _add_index, _add_constraint, _modify_col from .. import util log = logging.getLogger(__name__) ################################################### # public + + def compare_metadata(context, metadata): """Compare a database schema to that given in a :class:`~sqlalchemy.schema.MetaData` instance. @@ -113,10 +115,11 @@ def compare_metadata(context, metadata): ################################################### # top level + def _produce_migration_diffs(context, template_args, - imports, include_symbol=None, - include_object=None, - include_schemas=False): + imports, include_symbol=None, + include_object=None, + include_schemas=False): opts = context.opts metadata = opts['target_metadata'] include_schemas = opts.get('include_schemas', include_schemas) @@ -125,20 +128,20 @@ def _produce_migration_diffs(context, template_args, if metadata is None: raise util.CommandError( - "Can't proceed with --autogenerate option; environment " - "script %s does not provide " - "a MetaData object to the context." % ( - context.script.env_py_location - )) + "Can't proceed with --autogenerate option; environment " + "script %s does not provide " + "a MetaData object to the context." % ( + context.script.env_py_location + )) autogen_context, connection = _autogen_context(context, imports) diffs = [] _produce_net_changes(connection, metadata, diffs, - autogen_context, object_filters, include_schemas) + autogen_context, object_filters, include_schemas) template_args[opts['upgrade_token']] = \ - _indent(_produce_upgrade_commands(diffs, autogen_context)) + _indent(_produce_upgrade_commands(diffs, autogen_context)) template_args[opts['downgrade_token']] = \ - _indent(_produce_downgrade_commands(diffs, autogen_context)) + _indent(_produce_downgrade_commands(diffs, autogen_context)) template_args['imports'] = "\n".join(sorted(imports)) @@ -171,9 +174,10 @@ def _autogen_context(context, imports): 'opts': opts }, connection + def _indent(text): text = "### commands auto generated by Alembic - "\ - "please adjust! ###\n" + text + "please adjust! ###\n" + text text += "\n### end Alembic commands ###" text = re.compile(r'^', re.M).sub(" ", text).strip() return text @@ -183,8 +187,8 @@ def _indent(text): def _produce_net_changes(connection, metadata, diffs, autogen_context, - object_filters=(), - include_schemas=False): + object_filters=(), + include_schemas=False): inspector = Inspector.from_engine(connection) # TODO: not hardcode alembic_version here ? conn_table_names = set() @@ -202,11 +206,11 @@ def _produce_net_changes(connection, metadata, diffs, autogen_context, for s in schemas: tables = set(inspector.get_table_names(schema=s)).\ - difference(['alembic_version']) + difference(['alembic_version']) conn_table_names.update(zip([s] * len(tables), tables)) metadata_table_names = OrderedSet([(table.schema, table.name) - for table in metadata.sorted_tables]) + for table in metadata.sorted_tables]) _compare_tables(conn_table_names, metadata_table_names, object_filters, @@ -232,6 +236,7 @@ def _produce_upgrade_commands(diffs, autogen_context): buf = ["pass"] return "\n".join(buf) + def _produce_downgrade_commands(diffs, autogen_context): buf = [] for diff in reversed(diffs): @@ -240,12 +245,14 @@ def _produce_downgrade_commands(diffs, autogen_context): buf = ["pass"] return "\n".join(buf) + def _invoke_command(updown, args, autogen_context): if isinstance(args, tuple): return _invoke_adddrop_command(updown, args, autogen_context) else: return _invoke_modify_command(updown, args, autogen_context) + def _invoke_adddrop_command(updown, args, autogen_context): cmd_type = args[0] adddrop, cmd_type = cmd_type.split("_") @@ -270,6 +277,7 @@ def _invoke_adddrop_command(updown, args, autogen_context): else: return cmd_callables[0](*cmd_args) + def _invoke_modify_command(updown, args, autogen_context): sname, tname, cname = args[0][1:4] kw = {} @@ -281,9 +289,9 @@ def _invoke_modify_command(updown, args, autogen_context): } for diff in args: diff_kw = diff[4] - for arg in ("existing_type", \ - "existing_nullable", \ - "existing_server_default"): + for arg in ("existing_type", + "existing_nullable", + "existing_server_default"): if arg in diff_kw: kw.setdefault(arg, diff_kw[arg]) old_kw, new_kw = _arg_struct[diff[0]] diff --git a/alembic/autogenerate/compare.py b/alembic/autogenerate/compare.py index a50bc6d..cc24173 100644 --- a/alembic/autogenerate/compare.py +++ b/alembic/autogenerate/compare.py @@ -8,6 +8,7 @@ from sqlalchemy.util import OrderedSet log = logging.getLogger(__name__) + def _run_filters(object_, name, type_, reflected, compare_to, object_filters): for fn in object_filters: if not fn(object_, name, type_, reflected, compare_to): @@ -15,6 +16,7 @@ def _run_filters(object_, name, type_, reflected, compare_to, object_filters): else: return True + def _compare_tables(conn_table_names, metadata_table_names, object_filters, inspector, metadata, diffs, autogen_context): @@ -35,14 +37,14 @@ def _compare_tables(conn_table_names, metadata_table_names, # as "schemaname.tablename" or just "tablename", create a new lookup # which will match the "non-default-schema" keys to the Table object. tname_to_table = dict( - ( - no_dflt_schema, - metadata.tables[sa_schema._get_table_key(tname, schema)] - ) - for no_dflt_schema, (schema, tname) in zip( - metadata_table_names_no_dflt_schema, - metadata_table_names) - ) + ( + no_dflt_schema, + metadata.tables[sa_schema._get_table_key(tname, schema)] + ) + for no_dflt_schema, (schema, tname) in zip( + metadata_table_names_no_dflt_schema, + metadata_table_names) + ) metadata_table_names = metadata_table_names_no_dflt_schema for s, tname in metadata_table_names.difference(conn_table_names): @@ -52,9 +54,9 @@ def _compare_tables(conn_table_names, metadata_table_names, diffs.append(("add_table", metadata_table)) log.info("Detected added table %r", name) _compare_indexes_and_uniques(s, tname, object_filters, - None, - metadata_table, - diffs, autogen_context, inspector) + None, + metadata_table, + diffs, autogen_context, inspector) removal_metadata = sa_schema.MetaData() for s, tname in conn_table_names.difference(metadata_table_names): @@ -87,33 +89,36 @@ def _compare_tables(conn_table_names, metadata_table_names, if _run_filters(metadata_table, tname, "table", False, conn_table, object_filters): _compare_columns(s, tname, object_filters, - conn_table, - metadata_table, - diffs, autogen_context, inspector) + conn_table, + metadata_table, + diffs, autogen_context, inspector) _compare_indexes_and_uniques(s, tname, object_filters, - conn_table, - metadata_table, - diffs, autogen_context, inspector) + conn_table, + metadata_table, + diffs, autogen_context, inspector) # TODO: # table constraints # sequences + def _make_index(params, conn_table): return sa_schema.Index( - params['name'], - *[conn_table.c[cname] for cname in params['column_names']], - unique=params['unique'] + params['name'], + *[conn_table.c[cname] for cname in params['column_names']], + unique=params['unique'] ) + def _make_unique_constraint(params, conn_table): return sa_schema.UniqueConstraint( - *[conn_table.c[cname] for cname in params['column_names']], - name=params['name'] + *[conn_table.c[cname] for cname in params['column_names']], + name=params['name'] ) + def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, - diffs, autogen_context, inspector): + diffs, autogen_context, inspector): name = '%s.%s' % (schema, tname) if schema else tname metadata_cols_by_name = dict((c.name, c) for c in metadata_table.c) conn_col_names = dict((c.name, c) for c in conn_table.c) @@ -121,7 +126,7 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, for cname in metadata_col_names.difference(conn_col_names): if _run_filters(metadata_cols_by_name[cname], cname, - "column", False, None, object_filters): + "column", False, None, object_filters): diffs.append( ("add_column", schema, tname, metadata_cols_by_name[cname]) ) @@ -129,7 +134,7 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, for cname in set(conn_col_names).difference(metadata_col_names): if _run_filters(conn_table.c[cname], cname, - "column", True, None, object_filters): + "column", True, None, object_filters): diffs.append( ("remove_column", schema, tname, conn_table.c[cname]) ) @@ -139,28 +144,30 @@ def _compare_columns(schema, tname, object_filters, conn_table, metadata_table, metadata_col = metadata_cols_by_name[colname] conn_col = conn_table.c[colname] if not _run_filters( - metadata_col, colname, "column", False, conn_col, object_filters): + metadata_col, colname, "column", False, conn_col, object_filters): continue col_diff = [] _compare_type(schema, tname, colname, - conn_col, - metadata_col, - col_diff, autogen_context - ) + conn_col, + metadata_col, + col_diff, autogen_context + ) _compare_nullable(schema, tname, colname, - conn_col, - metadata_col.nullable, - col_diff, autogen_context - ) + conn_col, + metadata_col.nullable, + col_diff, autogen_context + ) _compare_server_default(schema, tname, colname, - conn_col, - metadata_col, - col_diff, autogen_context - ) + conn_col, + metadata_col, + col_diff, autogen_context + ) if col_diff: diffs.append(col_diff) + class _constraint_sig(object): + def __eq__(self, other): return self.const == other.const @@ -170,6 +177,7 @@ class _constraint_sig(object): def __hash__(self): return hash(self.const) + class _uq_constraint_sig(_constraint_sig): is_index = False is_unique = True @@ -183,6 +191,7 @@ class _uq_constraint_sig(_constraint_sig): def column_names(self): return [col.name for col in self.const.columns] + class _ix_constraint_sig(_constraint_sig): is_index = True @@ -196,21 +205,23 @@ class _ix_constraint_sig(_constraint_sig): def column_names(self): return _get_index_column_names(self.const) + def _get_index_column_names(idx): if compat.sqla_08: return [getattr(exp, "name", None) for exp in idx.expressions] else: return [getattr(col, "name", None) for col in idx.columns] + def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, - metadata_table, diffs, autogen_context, inspector): + metadata_table, diffs, autogen_context, inspector): is_create_table = conn_table is None # 1a. get raw indexes and unique constraints from metadata ... metadata_unique_constraints = set(uq for uq in metadata_table.constraints - if isinstance(uq, sa_schema.UniqueConstraint) - ) + if isinstance(uq, sa_schema.UniqueConstraint) + ) metadata_indexes = set(metadata_table.indexes) conn_uniques = conn_indexes = frozenset() @@ -222,7 +233,7 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, if hasattr(inspector, "get_unique_constraints"): try: conn_uniques = inspector.get_unique_constraints( - tname, schema=schema) + tname, schema=schema) supports_unique_constraints = True except NotImplementedError: pass @@ -234,26 +245,26 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, # 2. convert conn-level objects from raw inspector records # into schema objects conn_uniques = set(_make_unique_constraint(uq_def, conn_table) - for uq_def in conn_uniques) + for uq_def in conn_uniques) conn_indexes = set(_make_index(ix, conn_table) for ix in conn_indexes) # 3. give the dialect a chance to omit indexes and constraints that # we know are either added implicitly by the DB or that the DB # can't accurately report on autogen_context['context'].impl.\ - correct_for_autogen_constraints( - conn_uniques, conn_indexes, - metadata_unique_constraints, - metadata_indexes - ) + correct_for_autogen_constraints( + conn_uniques, conn_indexes, + metadata_unique_constraints, + metadata_indexes + ) # 4. organize the constraints into "signature" collections, the # _constraint_sig() objects provide a consistent facade over both # Index and UniqueConstraint so we can easily work with them # interchangeably metadata_unique_constraints = set(_uq_constraint_sig(uq) - for uq in metadata_unique_constraints - ) + for uq in metadata_unique_constraints + ) metadata_indexes = set(_ix_constraint_sig(ix) for ix in metadata_indexes) @@ -263,16 +274,16 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, # 5. index things by name, for those objects that have names metadata_names = dict( - (c.name, c) for c in - metadata_unique_constraints.union(metadata_indexes) - if c.name is not None) + (c.name, c) for c in + metadata_unique_constraints.union(metadata_indexes) + if c.name is not None) conn_uniques_by_name = dict((c.name, c) for c in conn_unique_constraints) conn_indexes_by_name = dict((c.name, c) for c in conn_indexes) conn_names = dict((c.name, c) for c in - conn_unique_constraints.union(conn_indexes) - if c.name is not None) + conn_unique_constraints.union(conn_indexes) + if c.name is not None) doubled_constraints = dict( (name, (conn_uniques_by_name[name], conn_indexes_by_name[name])) @@ -283,11 +294,11 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, # constraints. conn_uniques_by_sig = dict((uq.sig, uq) for uq in conn_unique_constraints) metadata_uniques_by_sig = dict( - (uq.sig, uq) for uq in metadata_unique_constraints) + (uq.sig, uq) for uq in metadata_unique_constraints) metadata_indexes_by_sig = dict( - (ix.sig, ix) for ix in metadata_indexes) + (ix.sig, ix) for ix in metadata_indexes) unnamed_metadata_uniques = dict((uq.sig, uq) for uq in - metadata_unique_constraints if uq.name is None) + metadata_unique_constraints if uq.name is None) # assumptions: # 1. a unique constraint or an index from the connection *always* @@ -301,10 +312,10 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, if obj.is_index: diffs.append(("add_index", obj.const)) log.info("Detected added index '%s' on %s", - obj.name, ', '.join([ - "'%s'" % obj.column_names - ]) - ) + obj.name, ', '.join([ + "'%s'" % obj.column_names + ]) + ) else: if not supports_unique_constraints: # can't report unique indexes as added if we don't @@ -315,10 +326,10 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, return diffs.append(("add_constraint", obj.const)) log.info("Detected added unique constraint '%s' on %s", - obj.name, ', '.join([ - "'%s'" % obj.column_names - ]) - ) + obj.name, ', '.join([ + "'%s'" % obj.column_names + ]) + ) def obj_removed(obj): if obj.is_index: @@ -333,20 +344,20 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, else: diffs.append(("remove_constraint", obj.const)) log.info("Detected removed unique constraint '%s' on '%s'", - obj.name, tname - ) + obj.name, tname + ) def obj_changed(old, new, msg): if old.is_index: log.info("Detected changed index '%s' on '%s':%s", - old.name, tname, ', '.join(msg) - ) + old.name, tname, ', '.join(msg) + ) diffs.append(("remove_index", old.const)) diffs.append(("add_index", new.const)) else: log.info("Detected changed unique constraint '%s' on '%s':%s", - old.name, tname, ', '.join(msg) - ) + old.name, tname, ', '.join(msg) + ) diffs.append(("remove_constraint", old.const)) diffs.append(("add_constraint", new.const)) @@ -354,7 +365,6 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, obj = metadata_names[added_name] obj_added(obj) - for existing_name in sorted(set(metadata_names).intersection(conn_names)): metadata_obj = metadata_names[existing_name] @@ -384,14 +394,13 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, if msg: obj_changed(conn_obj, metadata_obj, msg) - for removed_name in sorted(set(conn_names).difference(metadata_names)): conn_obj = conn_names[removed_name] if not conn_obj.is_index and conn_obj.sig in unnamed_metadata_uniques: continue elif removed_name in doubled_constraints: if conn_obj.sig not in metadata_indexes_by_sig and \ - conn_obj.sig not in metadata_uniques_by_sig: + conn_obj.sig not in metadata_uniques_by_sig: conn_uq, conn_idx = doubled_constraints[removed_name] obj_removed(conn_uq) obj_removed(conn_idx) @@ -404,8 +413,8 @@ def _compare_indexes_and_uniques(schema, tname, object_filters, conn_table, def _compare_nullable(schema, tname, cname, conn_col, - metadata_col_nullable, diffs, - autogen_context): + metadata_col_nullable, diffs, + autogen_context): conn_col_nullable = conn_col.nullable if conn_col_nullable is not metadata_col_nullable: diffs.append( @@ -418,24 +427,25 @@ def _compare_nullable(schema, tname, cname, conn_col, metadata_col_nullable), ) log.info("Detected %s on column '%s.%s'", - "NULL" if metadata_col_nullable else "NOT NULL", - tname, - cname - ) + "NULL" if metadata_col_nullable else "NOT NULL", + tname, + cname + ) + def _compare_type(schema, tname, cname, conn_col, - metadata_col, diffs, - autogen_context): + metadata_col, diffs, + autogen_context): conn_type = conn_col.type metadata_type = metadata_col.type if conn_type._type_affinity is sqltypes.NullType: log.info("Couldn't determine database type " - "for column '%s.%s'", tname, cname) + "for column '%s.%s'", tname, cname) return if metadata_type._type_affinity is sqltypes.NullType: log.info("Column '%s.%s' has no type within " - "the model; can't compare", tname, cname) + "the model; can't compare", tname, cname) return isdiff = autogen_context['context']._compare_type(conn_col, metadata_col) @@ -444,40 +454,42 @@ def _compare_type(schema, tname, cname, conn_col, diffs.append( ("modify_type", schema, tname, cname, - { - "existing_nullable": conn_col.nullable, - "existing_server_default": conn_col.server_default, - }, - conn_type, - metadata_type), + { + "existing_nullable": conn_col.nullable, + "existing_server_default": conn_col.server_default, + }, + conn_type, + metadata_type), ) log.info("Detected type change from %r to %r on '%s.%s'", - conn_type, metadata_type, tname, cname - ) + conn_type, metadata_type, tname, cname + ) + def _render_server_default_for_compare(metadata_default, - metadata_col, autogen_context): + metadata_col, autogen_context): return _render_server_default( - metadata_default, autogen_context, - repr_=metadata_col.type._type_affinity is sqltypes.String) + metadata_default, autogen_context, + repr_=metadata_col.type._type_affinity is sqltypes.String) + def _compare_server_default(schema, tname, cname, conn_col, metadata_col, - diffs, autogen_context): + diffs, autogen_context): metadata_default = metadata_col.server_default conn_col_default = conn_col.server_default if conn_col_default is None and metadata_default is None: return False rendered_metadata_default = _render_server_default_for_compare( - metadata_default, metadata_col, autogen_context) + metadata_default, metadata_col, autogen_context) rendered_conn_default = conn_col.server_default.arg.text \ - if conn_col.server_default else None + if conn_col.server_default else None isdiff = autogen_context['context']._compare_server_default( - conn_col, metadata_col, - rendered_metadata_default, - rendered_conn_default - ) + conn_col, metadata_col, + rendered_metadata_default, + rendered_conn_default + ) if isdiff: conn_col_default = rendered_conn_default diffs.append( @@ -490,6 +502,6 @@ def _compare_server_default(schema, tname, cname, conn_col, metadata_col, metadata_default), ) log.info("Detected server default on column '%s.%s'", - tname, - cname - ) + tname, + cname + ) diff --git a/alembic/autogenerate/render.py b/alembic/autogenerate/render.py index 81bd774..447870b 100644 --- a/alembic/autogenerate/render.py +++ b/alembic/autogenerate/render.py @@ -10,6 +10,7 @@ MAX_PYTHON_ARGS = 255 try: from sqlalchemy.sql.naming import conv + def _render_gen_name(autogen_context, name): if isinstance(name, conv): return _f_name(_alembic_autogenerate_prefix(autogen_context), name) @@ -19,7 +20,9 @@ except ImportError: def _render_gen_name(autogen_context, name): return name + class _f_name(object): + def __init__(self, prefix, name): self.prefix = prefix self.name = name @@ -27,6 +30,7 @@ class _f_name(object): def __repr__(self): return "%sf(%r)" % (self.prefix, self.name) + def _render_potential_expr(value, autogen_context): if isinstance(value, sql.ClauseElement): if compat.sqla_08: @@ -37,23 +41,24 @@ def _render_potential_expr(value, autogen_context): return "%(prefix)stext(%(sql)r)" % { "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), "sql": str( - value.compile(dialect=autogen_context['dialect'], - **compile_kw) - ) + value.compile(dialect=autogen_context['dialect'], + **compile_kw) + ) } else: return repr(value) + def _add_table(table, autogen_context): args = [col for col in [_render_column(col, autogen_context) for col in table.c] - if col] + \ + if col] + \ sorted([rcons for rcons in - [_render_constraint(cons, autogen_context) for cons in - table.constraints] - if rcons is not None - ]) + [_render_constraint(cons, autogen_context) for cons in + table.constraints] + if rcons is not None + ]) if len(args) > MAX_PYTHON_ARGS: args = '*[' + ',\n'.join(args) + ']' @@ -72,16 +77,18 @@ def _add_table(table, autogen_context): text += "\n)" return text + def _drop_table(table, autogen_context): text = "%(prefix)sdrop_table(%(tname)r" % { - "prefix": _alembic_autogenerate_prefix(autogen_context), - "tname": table.name - } + "prefix": _alembic_autogenerate_prefix(autogen_context), + "tname": table.name + } if table.schema: text += ", schema=%r" % table.schema text += ")" return text + def _add_index(index, autogen_context): """ Generate Alembic operations for the CREATE INDEX of an @@ -90,27 +97,28 @@ def _add_index(index, autogen_context): from .compare import _get_index_column_names text = "%(prefix)screate_index(%(name)r, '%(table)s', %(columns)s, "\ - "unique=%(unique)r%(schema)s%(kwargs)s)" % { - 'prefix': _alembic_autogenerate_prefix(autogen_context), - 'name': _render_gen_name(autogen_context, index.name), - 'table': index.table.name, - 'columns': _get_index_column_names(index), - 'unique': index.unique or False, - 'schema': (", schema='%s'" % index.table.schema) if index.table.schema else '', - 'kwargs': (', '+', '.join( - ["%s=%s" % (key, _render_potential_expr(val, autogen_context)) - for key, val in index.kwargs.items()]))\ + "unique=%(unique)r%(schema)s%(kwargs)s)" % { + 'prefix': _alembic_autogenerate_prefix(autogen_context), + 'name': _render_gen_name(autogen_context, index.name), + 'table': index.table.name, + 'columns': _get_index_column_names(index), + 'unique': index.unique or False, + 'schema': (", schema='%s'" % index.table.schema) if index.table.schema else '', + 'kwargs': (', ' + ', '.join( + ["%s=%s" % (key, _render_potential_expr(val, autogen_context)) + for key, val in index.kwargs.items()])) if len(index.kwargs) else '' - } + } return text + def _drop_index(index, autogen_context): """ Generate Alembic operations for the DROP INDEX of an :class:`~sqlalchemy.schema.Index` instance. """ text = "%(prefix)sdrop_index(%(name)r, "\ - "table_name='%(table_name)s'%(schema)s)" % { + "table_name='%(table_name)s'%(schema)s)" % { 'prefix': _alembic_autogenerate_prefix(autogen_context), 'name': _render_gen_name(autogen_context, index.name), 'table_name': index.table.name, @@ -135,6 +143,7 @@ def _add_unique_constraint(constraint, autogen_context): """ return _uq_constraint(constraint, autogen_context, True) + def _uq_constraint(constraint, autogen_context, alter): opts = [] if constraint.deferrable: @@ -148,13 +157,13 @@ def _uq_constraint(constraint, autogen_context, alter): if alter: args = [repr(_render_gen_name(autogen_context, constraint.name)), - repr(constraint.table.name)] + repr(constraint.table.name)] args.append(repr([col.name for col in constraint.columns])) args.extend(["%s=%r" % (k, v) for k, v in opts]) return "%(prefix)screate_unique_constraint(%(args)s)" % { - 'prefix': _alembic_autogenerate_prefix(autogen_context), - 'args': ", ".join(args) - } + 'prefix': _alembic_autogenerate_prefix(autogen_context), + 'args': ", ".join(args) + } else: args = [repr(col.name) for col in constraint.columns] args.extend(["%s=%r" % (k, v) for k, v in opts]) @@ -167,12 +176,15 @@ def _uq_constraint(constraint, autogen_context, alter): def _add_fk_constraint(constraint, autogen_context): raise NotImplementedError() + def _add_pk_constraint(constraint, autogen_context): raise NotImplementedError() + def _add_check_constraint(constraint, autogen_context): raise NotImplementedError() + def _add_constraint(constraint, autogen_context): """ Dispatcher for the different types of constraints. @@ -186,42 +198,46 @@ def _add_constraint(constraint, autogen_context): } return funcs[constraint.__visit_name__](constraint, autogen_context) + def _drop_constraint(constraint, autogen_context): """ Generate Alembic operations for the ALTER TABLE ... DROP CONSTRAINT of a :class:`~sqlalchemy.schema.UniqueConstraint` instance. """ text = "%(prefix)sdrop_constraint(%(name)r, '%(table_name)s'%(schema)s)" % { - 'prefix': _alembic_autogenerate_prefix(autogen_context), - 'name': _render_gen_name(autogen_context, constraint.name), - 'table_name': constraint.table.name, - 'schema': (", schema='%s'" % constraint.table.schema) - if constraint.table.schema else '', + 'prefix': _alembic_autogenerate_prefix(autogen_context), + 'name': _render_gen_name(autogen_context, constraint.name), + 'table_name': constraint.table.name, + 'schema': (", schema='%s'" % constraint.table.schema) + if constraint.table.schema else '', } return text + def _add_column(schema, tname, column, autogen_context): text = "%(prefix)sadd_column(%(tname)r, %(column)s" % { - "prefix": _alembic_autogenerate_prefix(autogen_context), - "tname": tname, - "column": _render_column(column, autogen_context) - } + "prefix": _alembic_autogenerate_prefix(autogen_context), + "tname": tname, + "column": _render_column(column, autogen_context) + } if schema: text += ", schema=%r" % schema text += ")" return text + def _drop_column(schema, tname, column, autogen_context): text = "%(prefix)sdrop_column(%(tname)r, %(cname)r" % { - "prefix": _alembic_autogenerate_prefix(autogen_context), - "tname": tname, - "cname": column.name - } + "prefix": _alembic_autogenerate_prefix(autogen_context), + "tname": tname, + "cname": column.name + } if schema: text += ", schema=%r" % schema text += ")" return text + def _modify_col(tname, cname, autogen_context, server_default=False, @@ -233,37 +249,38 @@ def _modify_col(tname, cname, schema=None): indent = " " * 11 text = "%(prefix)salter_column(%(tname)r, %(cname)r" % { - 'prefix': _alembic_autogenerate_prefix( - autogen_context), - 'tname': tname, - 'cname': cname} + 'prefix': _alembic_autogenerate_prefix( + autogen_context), + 'tname': tname, + 'cname': cname} text += ",\n%sexisting_type=%s" % (indent, - _repr_type(existing_type, autogen_context)) + _repr_type(existing_type, autogen_context)) if server_default is not False: rendered = _render_server_default( - server_default, autogen_context) + server_default, autogen_context) text += ",\n%sserver_default=%s" % (indent, rendered) if type_ is not None: text += ",\n%stype_=%s" % (indent, - _repr_type(type_, autogen_context)) + _repr_type(type_, autogen_context)) if nullable is not None: text += ",\n%snullable=%r" % ( - indent, nullable,) + indent, nullable,) if existing_nullable is not None: text += ",\n%sexisting_nullable=%r" % ( - indent, existing_nullable) + indent, existing_nullable) if existing_server_default: rendered = _render_server_default( - existing_server_default, - autogen_context) + existing_server_default, + autogen_context) text += ",\n%sexisting_server_default=%s" % ( - indent, rendered) + indent, rendered) if schema: text += ",\n%sschema=%r" % (indent, schema) text += ")" return text + def _user_autogenerate_prefix(autogen_context): prefix = autogen_context['opts']['user_module_prefix'] if prefix is None: @@ -271,12 +288,15 @@ def _user_autogenerate_prefix(autogen_context): else: return prefix + def _sqlalchemy_autogenerate_prefix(autogen_context): return autogen_context['opts']['sqlalchemy_module_prefix'] or '' + def _alembic_autogenerate_prefix(autogen_context): return autogen_context['opts']['alembic_module_prefix'] or '' + def _user_defined_render(type_, object_, autogen_context): if 'opts' in autogen_context and \ 'render_item' in autogen_context['opts']: @@ -287,6 +307,7 @@ def _user_defined_render(type_, object_, autogen_context): return rendered return False + def _render_column(column, autogen_context): rendered = _user_defined_render("column", column, autogen_context) if rendered is not False: @@ -295,8 +316,8 @@ def _render_column(column, autogen_context): opts = [] if column.server_default: rendered = _render_server_default( - column.server_default, autogen_context - ) + column.server_default, autogen_context + ) if rendered: opts.append(("server_default", rendered)) @@ -314,6 +335,7 @@ def _render_column(column, autogen_context): 'kw': ", ".join(["%s=%s" % (kwname, val) for kwname, val in opts]) } + def _render_server_default(default, autogen_context, repr_=True): rendered = _user_defined_render("server_default", default, autogen_context) if rendered is not False: @@ -324,7 +346,7 @@ def _render_server_default(default, autogen_context, repr_=True): default = default.arg else: default = str(default.arg.compile( - dialect=autogen_context['dialect'])) + dialect=autogen_context['dialect'])) if isinstance(default, string_types): if repr_: default = re.sub(r"^'|'$", "", default) @@ -334,6 +356,7 @@ def _render_server_default(default, autogen_context, repr_=True): else: return None + def _repr_type(type_, autogen_context): rendered = _user_defined_render("type", type_, autogen_context) if rendered is not False: @@ -353,6 +376,7 @@ def _repr_type(type_, autogen_context): prefix = _user_autogenerate_prefix(autogen_context) return "%s%r" % (prefix, type_) + def _render_constraint(constraint, autogen_context): renderer = _constraint_renderers.get(type(constraint), None) if renderer: @@ -360,6 +384,7 @@ def _render_constraint(constraint, autogen_context): else: return None + def _render_primary_key(constraint, autogen_context): rendered = _user_defined_render("primary_key", constraint, autogen_context) if rendered is not False: @@ -379,6 +404,7 @@ def _render_primary_key(constraint, autogen_context): ), } + def _fk_colspec(fk, metadata_schema): """Implement a 'safe' version of ForeignKey._get_colspec() that never tries to resolve the remote table. @@ -393,6 +419,7 @@ def _fk_colspec(fk, metadata_schema): colspec = "%s.%s" % (metadata_schema, colspec) return colspec + def _render_foreign_key(constraint, autogen_context): rendered = _user_defined_render("foreign_key", constraint, autogen_context) if rendered is not False: @@ -414,15 +441,16 @@ def _render_foreign_key(constraint, autogen_context): apply_metadata_schema = constraint.parent.metadata.schema return "%(prefix)sForeignKeyConstraint([%(cols)s], "\ - "[%(refcols)s], %(args)s)" % { - "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), - "cols": ", ".join("'%s'" % f.parent.key for f in constraint.elements), - "refcols": ", ".join(repr(_fk_colspec(f, apply_metadata_schema)) - for f in constraint.elements), - "args": ", ".join( - ["%s=%s" % (kwname, val) for kwname, val in opts] - ), - } + "[%(refcols)s], %(args)s)" % { + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "cols": ", ".join("'%s'" % f.parent.key for f in constraint.elements), + "refcols": ", ".join(repr(_fk_colspec(f, apply_metadata_schema)) + for f in constraint.elements), + "args": ", ".join( + ["%s=%s" % (kwname, val) for kwname, val in opts] + ), + } + def _render_check_constraint(constraint, autogen_context): rendered = _user_defined_render("check", constraint, autogen_context) @@ -436,21 +464,21 @@ def _render_check_constraint(constraint, autogen_context): if constraint._create_rule and \ hasattr(constraint._create_rule, 'target') and \ isinstance(constraint._create_rule.target, - sqltypes.TypeEngine): + sqltypes.TypeEngine): return None opts = [] if constraint.name: opts.append(("name", repr(_render_gen_name(autogen_context, constraint.name)))) return "%(prefix)sCheckConstraint(%(sqltext)r%(opts)s)" % { - "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), - "opts": ", " + (", ".join("%s=%s" % (k, v) - for k, v in opts)) if opts else "", - "sqltext": str( + "prefix": _sqlalchemy_autogenerate_prefix(autogen_context), + "opts": ", " + (", ".join("%s=%s" % (k, v) + for k, v in opts)) if opts else "", + "sqltext": str( constraint.sqltext.compile( dialect=autogen_context['dialect'] ) - ) - } + ) + } _constraint_renderers = { sa_schema.PrimaryKeyConstraint: _render_primary_key, diff --git a/alembic/command.py b/alembic/command.py index f1c5962..a6d7995 100644 --- a/alembic/command.py +++ b/alembic/command.py @@ -4,21 +4,23 @@ from .script import ScriptDirectory from .environment import EnvironmentContext from . import util, autogenerate as autogen + def list_templates(config): """List available templates""" config.print_stdout("Available templates:\n") for tempname in os.listdir(config.get_template_directory()): with open(os.path.join( - config.get_template_directory(), - tempname, - 'README')) as readme: + config.get_template_directory(), + tempname, + 'README')) as readme: synopsis = next(readme) config.print_stdout("%s - %s", tempname, synopsis) config.print_stdout("\nTemplates are used via the 'init' command, e.g.:") config.print_stdout("\n alembic init --template pylons ./scripts") + def init(config, directory, template='generic'): """Initialize a new scripts directory.""" @@ -26,7 +28,7 @@ def init(config, directory, template='generic'): raise util.CommandError("Directory %s already exists" % directory) template_dir = os.path.join(config.get_template_directory(), - template) + template) if not os.access(template_dir, os.F_OK): raise util.CommandError("No such template %r" % template) @@ -58,8 +60,9 @@ def init(config, directory, template='generic'): output_file ) - util.msg("Please edit configuration/connection/logging "\ - "settings in %r before proceeding." % config_file) + util.msg("Please edit configuration/connection/logging " + "settings in %r before proceeding." % config_file) + def revision(config, message=None, autogenerate=False, sql=False): """Create a new revision file.""" @@ -77,6 +80,7 @@ def revision(config, message=None, autogenerate=False, sql=False): if autogenerate: environment = True + def retrieve_migrations(rev, context): if script.get_revision(rev) is not script.get_revision("head"): raise util.CommandError("Target database is not up to date.") @@ -124,6 +128,7 @@ def upgrade(config, revision, sql=False, tag=None): ): script.run_env() + def downgrade(config, revision, sql=False, tag=None): """Revert to a previous version.""" @@ -150,6 +155,7 @@ def downgrade(config, revision, sql=False, tag=None): ): script.run_env() + def history(config, rev_range=None): """List changeset scripts in chronological order.""" @@ -157,16 +163,16 @@ def history(config, rev_range=None): if rev_range is not None: if ":" not in rev_range: raise util.CommandError( - "History range requires [start]:[end], " - "[start]:, or :[end]") + "History range requires [start]:[end], " + "[start]:, or :[end]") base, head = rev_range.strip().split(":") else: base = head = None def _display_history(config, script, base, head): for sc in script.walk_revisions( - base=base or "base", - head=head or "head"): + base=base or "base", + head=head or "head"): if sc.is_head: config.print_stdout("") config.print_stdout(sc.log_entry) @@ -202,14 +208,16 @@ def branches(config): config.print_stdout(sc) for rev in sc.nextrev: config.print_stdout("%s -> %s", - " " * len(str(sc.down_revision)), - script.get_revision(rev) - ) + " " * len(str(sc.down_revision)), + script.get_revision(rev) + ) + def current(config, head_only=False): """Display the current revision for each database.""" script = ScriptDirectory.from_config(config) + def display_version(rev, context): rev = script.get_revision(rev) @@ -232,11 +240,13 @@ def current(config, head_only=False): ): script.run_env() + def stamp(config, revision, sql=False, tag=None): """'stamp' the revision table with the given revision; don't run any migrations.""" script = ScriptDirectory.from_config(config) + def do_stamp(rev, context): if sql: current = False @@ -257,6 +267,7 @@ def stamp(config, revision, sql=False, tag=None): ): script.run_env() + def splice(config, parent, child): """'splice' two branches, creating a new revision file. diff --git a/alembic/compat.py b/alembic/compat.py index aac0560..cded54b 100644 --- a/alembic/compat.py +++ b/alembic/compat.py @@ -17,6 +17,7 @@ if py3k: string_types = str, binary_type = bytes text_type = str + def callable(fn): return hasattr(fn, '__call__') @@ -45,6 +46,7 @@ if py2k: if py33: from importlib import machinery + def load_module_py(module_id, path): return machinery.SourceFileLoader(module_id, path).load_module(module_id) @@ -53,6 +55,7 @@ if py33: else: import imp + def load_module_py(module_id, path): with open(path, 'rb') as fp: mod = imp.load_source(module_id, path, fp) @@ -78,6 +81,8 @@ except AttributeError: ################################################ # cross-compatible metaclass implementation # Copyright (c) 2010-2012 Benjamin Peterson + + def with_metaclass(meta, base=object): """Create a base class with a metaclass.""" return meta("%sBase" % meta.__name__, (base,), {}) @@ -88,6 +93,7 @@ def with_metaclass(meta, base=object): # into a given buffer, but doesn't close it. # not sure of a more idiomatic approach to this. class EncodedIO(io.TextIOWrapper): + def close(self): pass @@ -99,10 +105,12 @@ if py2k: # adapter. class ActLikePy3kIO(object): + """Produce an object capable of wrapping either sys.stdout (e.g. file) *or* StringIO.StringIO(). """ + def _false(self): return False @@ -123,8 +131,7 @@ if py2k: return self.file_.flush() class EncodedIO(EncodedIO): + def __init__(self, file_, encoding): super(EncodedIO, self).__init__( - ActLikePy3kIO(file_), encoding=encoding) - - + ActLikePy3kIO(file_), encoding=encoding) diff --git a/alembic/config.py b/alembic/config.py index 86ff1df..003949b 100644 --- a/alembic/config.py +++ b/alembic/config.py @@ -6,7 +6,9 @@ import sys from . import command, util, package_dir, compat + class Config(object): + """Represent an Alembic configuration. Within an ``env.py`` script, this is available @@ -50,8 +52,9 @@ class Config(object): ..versionadded:: 0.4 """ + def __init__(self, file_=None, ini_section='alembic', output_buffer=None, - stdout=sys.stdout, cmd_opts=None): + stdout=sys.stdout, cmd_opts=None): """Construct a new :class:`.Config` """ @@ -90,9 +93,9 @@ class Config(object): """Render a message to standard out.""" util.write_outstream( - self.stdout, - (compat.text_type(text) % arg), - "\n" + self.stdout, + (compat.text_type(text) % arg), + "\n" ) @util.memoized_property @@ -162,8 +165,8 @@ class Config(object): """ if not self.file_config.has_section(section): raise util.CommandError("No config file %r found, or file has no " - "'[%s]' section" % - (self.config_file_name, section)) + "'[%s]' section" % + (self.config_file_name, section)) if self.file_config.has_option(section, name): return self.file_config.get(section, name) else: @@ -181,35 +184,35 @@ class Config(object): class CommandLine(object): + def __init__(self, prog=None): self._generate_args(prog) - def _generate_args(self, prog): def add_options(parser, positional, kwargs): if 'template' in kwargs: parser.add_argument("-t", "--template", - default='generic', - type=str, - help="Setup template for use with 'init'") + default='generic', + type=str, + help="Setup template for use with 'init'") if 'message' in kwargs: parser.add_argument("-m", "--message", - type=str, - help="Message string to use with 'revision'") + type=str, + help="Message string to use with 'revision'") if 'sql' in kwargs: parser.add_argument("--sql", - action="store_true", - help="Don't emit SQL to database - dump to " - "standard output/file instead") + action="store_true", + help="Don't emit SQL to database - dump to " + "standard output/file instead") if 'tag' in kwargs: parser.add_argument("--tag", - type=str, - help="Arbitrary 'tag' name - can be used by " - "custom env.py scripts.") + type=str, + help="Arbitrary 'tag' name - can be used by " + "custom env.py scripts.") if 'autogenerate' in kwargs: parser.add_argument("--autogenerate", - action="store_true", - help="Populate revision script with candidate " + action="store_true", + help="Populate revision script with candidate " "migration operations, based on comparison " "of database to model.") # "current" command @@ -225,7 +228,6 @@ class CommandLine(object): help="Specify a revision range; " "format is [start]:[end]") - positional_help = { 'directory': "location of scripts directory", 'revision': "revision identifier" @@ -252,8 +254,8 @@ class CommandLine(object): for fn in [getattr(command, n) for n in dir(command)]: if inspect.isfunction(fn) and \ - fn.__name__[0] != '_' and \ - fn.__module__ == 'alembic.command': + fn.__name__[0] != '_' and \ + fn.__module__ == 'alembic.command': spec = inspect.getargspec(fn) if spec[3]: @@ -264,8 +266,8 @@ class CommandLine(object): kwarg = [] subparser = subparsers.add_parser( - fn.__name__, - help=fn.__doc__) + fn.__name__, + help=fn.__doc__) add_options(subparser, positional, kwarg) subparser.set_defaults(cmd=(fn, positional, kwarg)) self.parser = parser @@ -275,9 +277,9 @@ class CommandLine(object): try: fn(config, - *[getattr(options, k) for k in positional], - **dict((k, getattr(options, k)) for k in kwarg) - ) + *[getattr(options, k) for k in positional], + **dict((k, getattr(options, k)) for k in kwarg) + ) except util.CommandError as e: util.err(str(e)) @@ -289,13 +291,14 @@ class CommandLine(object): self.parser.error("too few arguments") else: cfg = Config(file_=options.config, - ini_section=options.name, cmd_opts=options) + ini_section=options.name, cmd_opts=options) self.run_cmd(cfg, options) + def main(argv=None, prog=None, **kwargs): """The console runner function for Alembic.""" CommandLine(prog=prog).main(argv=argv) if __name__ == '__main__': - main()
\ No newline at end of file + main() diff --git a/alembic/ddl/base.py b/alembic/ddl/base.py index 5d703a5..3a60926 100644 --- a/alembic/ddl/base.py +++ b/alembic/ddl/base.py @@ -5,62 +5,81 @@ from sqlalchemy.schema import DDLElement, Column from sqlalchemy import Integer from sqlalchemy import types as sqltypes + class AlterTable(DDLElement): + """Represent an ALTER TABLE statement. Only the string name and optional schema name of the table is required, not a full Table object. """ + def __init__(self, table_name, schema=None): self.table_name = table_name self.schema = schema + class RenameTable(AlterTable): + def __init__(self, old_table_name, new_table_name, schema=None): super(RenameTable, self).__init__(old_table_name, schema=schema) self.new_table_name = new_table_name + class AlterColumn(AlterTable): + def __init__(self, name, column_name, schema=None, - existing_type=None, - existing_nullable=None, - existing_server_default=None): + existing_type=None, + existing_nullable=None, + existing_server_default=None): super(AlterColumn, self).__init__(name, schema=schema) self.column_name = column_name - self.existing_type=sqltypes.to_instance(existing_type) \ - if existing_type is not None else None - self.existing_nullable=existing_nullable - self.existing_server_default=existing_server_default + self.existing_type = sqltypes.to_instance(existing_type) \ + if existing_type is not None else None + self.existing_nullable = existing_nullable + self.existing_server_default = existing_server_default + class ColumnNullable(AlterColumn): + def __init__(self, name, column_name, nullable, **kw): super(ColumnNullable, self).__init__(name, column_name, - **kw) + **kw) self.nullable = nullable + class ColumnType(AlterColumn): + def __init__(self, name, column_name, type_, **kw): super(ColumnType, self).__init__(name, column_name, - **kw) + **kw) self.type_ = sqltypes.to_instance(type_) + class ColumnName(AlterColumn): + def __init__(self, name, column_name, newname, **kw): super(ColumnName, self).__init__(name, column_name, **kw) self.newname = newname + class ColumnDefault(AlterColumn): + def __init__(self, name, column_name, default, **kw): super(ColumnDefault, self).__init__(name, column_name, **kw) self.default = default + class AddColumn(AlterTable): + def __init__(self, name, column, schema=None): super(AddColumn, self).__init__(name, schema=schema) self.column = column + class DropColumn(AlterTable): + def __init__(self, name, column, schema=None): super(DropColumn, self).__init__(name, schema=schema) self.column = column @@ -73,6 +92,7 @@ def visit_rename_table(element, compiler, **kw): format_table_name(compiler, element.new_table_name, element.schema) ) + @compiles(AddColumn) def visit_add_column(element, compiler, **kw): return "%s %s" % ( @@ -80,6 +100,7 @@ def visit_add_column(element, compiler, **kw): add_column(compiler, element.column, **kw) ) + @compiles(DropColumn) def visit_drop_column(element, compiler, **kw): return "%s %s" % ( @@ -87,6 +108,7 @@ def visit_drop_column(element, compiler, **kw): drop_column(compiler, element.column.name, **kw) ) + @compiles(ColumnNullable) def visit_column_nullable(element, compiler, **kw): return "%s %s %s" % ( @@ -95,6 +117,7 @@ def visit_column_nullable(element, compiler, **kw): "DROP NOT NULL" if element.nullable else "SET NOT NULL" ) + @compiles(ColumnType) def visit_column_type(element, compiler, **kw): return "%s %s %s" % ( @@ -103,6 +126,7 @@ def visit_column_type(element, compiler, **kw): "TYPE %s" % format_type(compiler, element.type_) ) + @compiles(ColumnName) def visit_column_name(element, compiler, **kw): return "%s RENAME %s TO %s" % ( @@ -111,23 +135,26 @@ def visit_column_name(element, compiler, **kw): format_column_name(compiler, element.newname) ) + @compiles(ColumnDefault) def visit_column_default(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), "SET DEFAULT %s" % - format_server_default(compiler, element.default) + format_server_default(compiler, element.default) if element.default is not None else "DROP DEFAULT" ) + def quote_dotted(name, quote): """quote the elements of a dotted name""" result = '.'.join([quote(x) for x in name.split('.')]) return result + def format_table_name(compiler, name, schema): quote = functools.partial(compiler.preparer.quote, force=None) if schema: @@ -135,27 +162,32 @@ def format_table_name(compiler, name, schema): else: return quote(name) + def format_column_name(compiler, name): return compiler.preparer.quote(name, None) + def format_server_default(compiler, default): return compiler.get_column_default_string( - Column("x", Integer, server_default=default) - ) + Column("x", Integer, server_default=default) + ) + def format_type(compiler, type_): return compiler.dialect.type_compiler.process(type_) + def alter_table(compiler, name, schema): return "ALTER TABLE %s" % format_table_name(compiler, name, schema) + def drop_column(compiler, name): return 'DROP COLUMN %s' % format_column_name(compiler, name) + def alter_column(compiler, name): return 'ALTER COLUMN %s' % format_column_name(compiler, name) + def add_column(compiler, column, **kw): return "ADD COLUMN %s" % compiler.get_column_specification(column, **kw) - - diff --git a/alembic/ddl/impl.py b/alembic/ddl/impl.py index 664158f..a22a4fb 100644 --- a/alembic/ddl/impl.py +++ b/alembic/ddl/impl.py @@ -8,7 +8,9 @@ from ..compat import string_types, text_type, with_metaclass from .. import util from . import base + class ImplMeta(type): + def __init__(cls, classname, bases, dict_): newtype = type.__init__(cls, classname, bases, dict_) if '__dialect__' in dict_: @@ -17,7 +19,9 @@ class ImplMeta(type): _impls = {} + class DefaultImpl(with_metaclass(ImplMeta)): + """Provide the entrypoint for major migration operations, including database-specific behavioral variances. @@ -35,8 +39,8 @@ class DefaultImpl(with_metaclass(ImplMeta)): command_terminator = ";" def __init__(self, dialect, connection, as_sql, - transactional_ddl, output_buffer, - context_opts): + transactional_ddl, output_buffer, + context_opts): self.dialect = dialect self.connection = connection self.as_sql = as_sql @@ -59,8 +63,8 @@ class DefaultImpl(with_metaclass(ImplMeta)): return self.connection def _exec(self, construct, execution_options=None, - multiparams=(), - params=util.immutabledict()): + multiparams=(), + params=util.immutabledict()): if isinstance(construct, string_types): construct = text(construct) if self.as_sql: @@ -68,8 +72,8 @@ class DefaultImpl(with_metaclass(ImplMeta)): # TODO: coverage raise Exception("Execution arguments not allowed with as_sql") self.static_output(text_type( - construct.compile(dialect=self.dialect) - ).replace("\t", " ").strip() + self.command_terminator) + construct.compile(dialect=self.dialect) + ).replace("\t", " ").strip() + self.command_terminator) else: conn = self.connection if execution_options: @@ -80,49 +84,49 @@ class DefaultImpl(with_metaclass(ImplMeta)): self._exec(sql, execution_options) def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - schema=None, - autoincrement=None, - existing_type=None, - existing_server_default=None, - existing_nullable=None, - existing_autoincrement=None - ): + nullable=None, + server_default=False, + name=None, + type_=None, + schema=None, + autoincrement=None, + existing_type=None, + existing_server_default=None, + existing_nullable=None, + existing_autoincrement=None + ): if autoincrement is not None or existing_autoincrement is not None: util.warn("nautoincrement and existing_autoincrement only make sense for MySQL") if nullable is not None: self._exec(base.ColumnNullable(table_name, column_name, - nullable, schema=schema, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - )) + nullable, schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + )) if server_default is not False: self._exec(base.ColumnDefault( - table_name, column_name, server_default, - schema=schema, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - )) + table_name, column_name, server_default, + schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + )) if type_ is not None: self._exec(base.ColumnType( - table_name, column_name, type_, schema=schema, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - )) + table_name, column_name, type_, schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + )) # do the new name last ;) if name is not None: self._exec(base.ColumnName( - table_name, column_name, name, schema=schema, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - )) + table_name, column_name, name, schema=schema, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + )) def add_column(self, table_name, column, schema=None): self._exec(base.AddColumn(table_name, column, schema=schema)) @@ -132,7 +136,7 @@ class DefaultImpl(with_metaclass(ImplMeta)): def add_constraint(self, const): if const._create_rule is None or \ - const._create_rule(self): + const._create_rule(self): self._exec(schema.AddConstraint(const)) def drop_constraint(self, const): @@ -140,18 +144,18 @@ class DefaultImpl(with_metaclass(ImplMeta)): def rename_table(self, old_table_name, new_table_name, schema=None): self._exec(base.RenameTable(old_table_name, - new_table_name, schema=schema)) + new_table_name, schema=schema)) def create_table(self, table): if util.sqla_07: table.dispatch.before_create(table, self.connection, - checkfirst=False, - _ddl_runner=self) + checkfirst=False, + _ddl_runner=self) self._exec(schema.CreateTable(table)) if util.sqla_07: table.dispatch.after_create(table, self.connection, checkfirst=False, - _ddl_runner=self) + _ddl_runner=self) for index in table.indexes: self._exec(schema.CreateIndex(index)) @@ -200,8 +204,8 @@ class DefaultImpl(with_metaclass(ImplMeta)): metadata_impl.__dict__.pop('_type_affinity', None) if conn_type._compare_type_affinity( - metadata_impl - ): + metadata_impl + ): comparator = _type_comparators.get(conn_type._type_affinity, None) return comparator and comparator(metadata_type, conn_type) @@ -209,9 +213,9 @@ class DefaultImpl(with_metaclass(ImplMeta)): return True def compare_server_default(self, inspector_column, - metadata_column, - rendered_metadata_default, - rendered_inspector_default): + metadata_column, + rendered_metadata_default, + rendered_inspector_default): return rendered_inspector_default != rendered_metadata_default def correct_for_autogen_constraints(self, conn_uniques, conn_indexes, @@ -247,9 +251,11 @@ class DefaultImpl(with_metaclass(ImplMeta)): """ self.static_output("COMMIT" + self.command_terminator) + class _literal_bindparam(_BindParamClause): pass + @compiles(_literal_bindparam) def _render_literal_bindparam(element, compiler, **kw): return compiler.render_literal_bindparam(element, **kw) @@ -268,6 +274,7 @@ def _textual_index_column(table, text_): class _textual_index_element(sql.ColumnElement): + """Wrap around a sqlalchemy text() construct in such a way that we appear like a column-oriented SQL expression to an Index construct. @@ -305,21 +312,18 @@ def _string_compare(t1, t2): t1.length is not None and \ t1.length != t2.length + def _numeric_compare(t1, t2): return \ ( - t1.precision is not None and \ + t1.precision is not None and t1.precision != t2.precision ) or \ ( - t1.scale is not None and \ + t1.scale is not None and t1.scale != t2.scale ) _type_comparators = { - sqltypes.String:_string_compare, - sqltypes.Numeric:_numeric_compare + sqltypes.String: _string_compare, + sqltypes.Numeric: _numeric_compare } - - - - diff --git a/alembic/ddl/mssql.py b/alembic/ddl/mssql.py index a3c67d6..d6c835c 100644 --- a/alembic/ddl/mssql.py +++ b/alembic/ddl/mssql.py @@ -4,9 +4,10 @@ from .. import util from .impl import DefaultImpl from .base import alter_table, AddColumn, ColumnName, RenameTable,\ format_table_name, format_column_name, ColumnNullable, alter_column,\ - format_server_default,ColumnDefault, format_type, ColumnType + format_server_default, ColumnDefault, format_type, ColumnType from sqlalchemy.sql.expression import ClauseElement, Executable + class MSSQLImpl(DefaultImpl): __dialect__ = 'mssql' transactional_ddl = True @@ -15,8 +16,8 @@ class MSSQLImpl(DefaultImpl): def __init__(self, *arg, **kw): super(MSSQLImpl, self).__init__(*arg, **kw) self.batch_separator = self.context_opts.get( - "mssql_batch_separator", - self.batch_separator) + "mssql_batch_separator", + self.batch_separator) def _exec(self, construct, *args, **kw): super(MSSQLImpl, self)._exec(construct, *args, **kw) @@ -32,17 +33,17 @@ class MSSQLImpl(DefaultImpl): self.static_output(self.batch_separator) def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - schema=None, - autoincrement=None, - existing_type=None, - existing_server_default=None, - existing_nullable=None, - existing_autoincrement=None - ): + nullable=None, + server_default=False, + name=None, + type_=None, + schema=None, + autoincrement=None, + existing_type=None, + existing_server_default=None, + existing_nullable=None, + existing_autoincrement=None + ): if nullable is not None and existing_type is None: if type_ is not None: @@ -52,70 +53,69 @@ class MSSQLImpl(DefaultImpl): type_ = None else: raise util.CommandError( - "MS-SQL ALTER COLUMN operations " - "with NULL or NOT NULL require the " - "existing_type or a new type_ be passed.") + "MS-SQL ALTER COLUMN operations " + "with NULL or NOT NULL require the " + "existing_type or a new type_ be passed.") super(MSSQLImpl, self).alter_column( - table_name, column_name, - nullable=nullable, - type_=type_, - schema=schema, - autoincrement=autoincrement, - existing_type=existing_type, - existing_nullable=existing_nullable, - existing_autoincrement=existing_autoincrement + table_name, column_name, + nullable=nullable, + type_=type_, + schema=schema, + autoincrement=autoincrement, + existing_type=existing_type, + existing_nullable=existing_nullable, + existing_autoincrement=existing_autoincrement ) if server_default is not False: if existing_server_default is not False or \ - server_default is None: + server_default is None: self._exec( _ExecDropConstraint( - table_name, column_name, - 'sys.default_constraints') + table_name, column_name, + 'sys.default_constraints') ) if server_default is not None: super(MSSQLImpl, self).alter_column( - table_name, column_name, - schema=schema, - server_default=server_default) + table_name, column_name, + schema=schema, + server_default=server_default) if name is not None: super(MSSQLImpl, self).alter_column( - table_name, column_name, - schema=schema, - name=name) + table_name, column_name, + schema=schema, + name=name) def bulk_insert(self, table, rows, **kw): if self.as_sql: self._exec( "SET IDENTITY_INSERT %s ON" % - self.dialect.identifier_preparer.format_table(table) + self.dialect.identifier_preparer.format_table(table) ) super(MSSQLImpl, self).bulk_insert(table, rows, **kw) self._exec( "SET IDENTITY_INSERT %s OFF" % - self.dialect.identifier_preparer.format_table(table) + self.dialect.identifier_preparer.format_table(table) ) else: super(MSSQLImpl, self).bulk_insert(table, rows, **kw) - def drop_column(self, table_name, column, **kw): drop_default = kw.pop('mssql_drop_default', False) if drop_default: self._exec( _ExecDropConstraint( - table_name, column, - 'sys.default_constraints') + table_name, column, + 'sys.default_constraints') ) drop_check = kw.pop('mssql_drop_check', False) if drop_check: self._exec( _ExecDropConstraint( - table_name, column, - 'sys.check_constraints') + table_name, column, + 'sys.check_constraints') ) drop_fks = kw.pop('mssql_drop_foreign_key', False) if drop_fks: @@ -124,13 +124,17 @@ class MSSQLImpl(DefaultImpl): ) super(MSSQLImpl, self).drop_column(table_name, column) + class _ExecDropConstraint(Executable, ClauseElement): + def __init__(self, tname, colname, type_): self.tname = tname self.colname = colname self.type_ = type_ + class _ExecDropFKConstraint(Executable, ClauseElement): + def __init__(self, tname, colname): self.tname = tname self.colname = colname @@ -152,6 +156,7 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % { 'tname_quoted': format_table_name(compiler, tname, None), } + @compiles(_ExecDropFKConstraint, 'mssql') def _exec_drop_col_fk_constraint(element, compiler, **kw): tname, colname = element.tname, element.colname @@ -169,7 +174,6 @@ exec('alter table %(tname_quoted)s drop constraint ' + @const_name)""" % { } - @compiles(AddColumn, 'mssql') def visit_add_column(element, compiler, **kw): return "%s %s" % ( @@ -177,9 +181,11 @@ def visit_add_column(element, compiler, **kw): mssql_add_column(compiler, element.column, **kw) ) + def mssql_add_column(compiler, column, **kw): return "ADD %s" % compiler.get_column_specification(column, **kw) + @compiles(ColumnNullable, 'mssql') def visit_column_nullable(element, compiler, **kw): return "%s %s %s %s" % ( @@ -189,6 +195,7 @@ def visit_column_nullable(element, compiler, **kw): "NULL" if element.nullable else "NOT NULL" ) + @compiles(ColumnDefault, 'mssql') def visit_column_default(element, compiler, **kw): # TODO: there can also be a named constraint @@ -199,6 +206,7 @@ def visit_column_default(element, compiler, **kw): format_column_name(compiler, element.column_name) ) + @compiles(ColumnName, 'mssql') def visit_rename_column(element, compiler, **kw): return "EXEC sp_rename '%s.%s', %s, 'COLUMN'" % ( @@ -207,6 +215,7 @@ def visit_rename_column(element, compiler, **kw): format_column_name(compiler, element.newname) ) + @compiles(ColumnType, 'mssql') def visit_column_type(element, compiler, **kw): return "%s %s %s" % ( @@ -215,6 +224,7 @@ def visit_column_type(element, compiler, **kw): format_type(compiler, element.type_) ) + @compiles(RenameTable, 'mssql') def visit_rename_table(element, compiler, **kw): return "EXEC sp_rename '%s', %s" % ( diff --git a/alembic/ddl/mysql.py b/alembic/ddl/mysql.py index 58d5c70..7545df7 100644 --- a/alembic/ddl/mysql.py +++ b/alembic/ddl/mysql.py @@ -6,27 +6,28 @@ from ..compat import string_types from .. import util from .impl import DefaultImpl from .base import ColumnNullable, ColumnName, ColumnDefault, \ - ColumnType, AlterColumn, format_column_name, \ - format_server_default + ColumnType, AlterColumn, format_column_name, \ + format_server_default from .base import alter_table + class MySQLImpl(DefaultImpl): __dialect__ = 'mysql' transactional_ddl = False def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - name=None, - type_=None, - schema=None, - autoincrement=None, - existing_type=None, - existing_server_default=None, - existing_nullable=None, - existing_autoincrement=None - ): + nullable=None, + server_default=False, + name=None, + type_=None, + schema=None, + autoincrement=None, + existing_type=None, + existing_server_default=None, + existing_nullable=None, + existing_autoincrement=None + ): if name is not None: self._exec( MySQLChangeColumn( @@ -34,33 +35,33 @@ class MySQLImpl(DefaultImpl): schema=schema, newname=name, nullable=nullable if nullable is not None else - existing_nullable - if existing_nullable is not None - else True, + existing_nullable + if existing_nullable is not None + else True, type_=type_ if type_ is not None else existing_type, default=server_default if server_default is not False - else existing_server_default, + else existing_server_default, autoincrement=autoincrement if autoincrement is not None - else existing_autoincrement + else existing_autoincrement ) ) elif nullable is not None or \ - type_ is not None or \ - autoincrement is not None: + type_ is not None or \ + autoincrement is not None: self._exec( MySQLModifyColumn( table_name, column_name, schema=schema, newname=name if name is not None else column_name, nullable=nullable if nullable is not None else - existing_nullable - if existing_nullable is not None - else True, + existing_nullable + if existing_nullable is not None + else True, type_=type_ if type_ is not None else existing_type, default=server_default if server_default is not False - else existing_server_default, + else existing_server_default, autoincrement=autoincrement if autoincrement is not None - else existing_autoincrement + else existing_autoincrement ) ) elif server_default is not False: @@ -99,7 +100,9 @@ class MySQLImpl(DefaultImpl): if idx.name in removed: metadata_indexes.remove(idx) + class MySQLAlterDefault(AlterColumn): + def __init__(self, name, column_name, default, schema=None): super(AlterColumn, self).__init__(name, schema=schema) self.column_name = column_name @@ -107,12 +110,13 @@ class MySQLAlterDefault(AlterColumn): class MySQLChangeColumn(AlterColumn): + def __init__(self, name, column_name, schema=None, - newname=None, - type_=None, - nullable=None, - default=False, - autoincrement=None): + newname=None, + type_=None, + nullable=None, + default=False, + autoincrement=None): super(AlterColumn, self).__init__(name, schema=schema) self.column_name = column_name self.nullable = nullable @@ -127,6 +131,7 @@ class MySQLChangeColumn(AlterColumn): self.type_ = sqltypes.to_instance(type_) + class MySQLModifyColumn(MySQLChangeColumn): pass @@ -137,8 +142,8 @@ class MySQLModifyColumn(MySQLChangeColumn): @compiles(ColumnType, 'mysql') def _mysql_doesnt_support_individual(element, compiler, **kw): raise NotImplementedError( - "Individual alter column constructs not supported by MySQL" - ) + "Individual alter column constructs not supported by MySQL" + ) @compiles(MySQLAlterDefault, "mysql") @@ -147,10 +152,11 @@ def _mysql_alter_default(element, compiler, **kw): alter_table(compiler, element.table_name, element.schema), format_column_name(compiler, element.column_name), "SET DEFAULT %s" % format_server_default(compiler, element.default) - if element.default is not None - else "DROP DEFAULT" + if element.default is not None + else "DROP DEFAULT" ) + @compiles(MySQLModifyColumn, "mysql") def _mysql_modify_column(element, compiler, **kw): return "%s MODIFY %s %s" % ( @@ -181,14 +187,16 @@ def _mysql_change_column(element, compiler, **kw): ), ) + def _render_value(compiler, expr): if isinstance(expr, string_types): return "'%s'" % expr else: return compiler.sql_compiler.process(expr) + def _mysql_colspec(compiler, nullable, server_default, type_, - autoincrement): + autoincrement): spec = "%s %s" % ( compiler.dialect.type_compiler.process(type_), "NULL" if nullable else "NOT NULL" @@ -200,6 +208,7 @@ def _mysql_colspec(compiler, nullable, server_default, type_, return spec + @compiles(schema.DropConstraint, "mysql") def _mysql_drop_constraint(element, compiler, **kw): """Redefine SQLAlchemy's drop constraint to @@ -207,15 +216,14 @@ def _mysql_drop_constraint(element, compiler, **kw): constraint = element.element if isinstance(constraint, (schema.ForeignKeyConstraint, - schema.PrimaryKeyConstraint, - schema.UniqueConstraint) - ): + schema.PrimaryKeyConstraint, + schema.UniqueConstraint) + ): return compiler.visit_drop_constraint(element, **kw) elif isinstance(constraint, schema.CheckConstraint): raise NotImplementedError( - "MySQL does not support CHECK constraints.") + "MySQL does not support CHECK constraints.") else: raise NotImplementedError( - "No generic 'DROP CONSTRAINT' in MySQL - " - "please specify constraint type") - + "No generic 'DROP CONSTRAINT' in MySQL - " + "please specify constraint type") diff --git a/alembic/ddl/oracle.py b/alembic/ddl/oracle.py index 28eb246..93e71e5 100644 --- a/alembic/ddl/oracle.py +++ b/alembic/ddl/oracle.py @@ -3,7 +3,8 @@ from sqlalchemy.ext.compiler import compiles from .impl import DefaultImpl from .base import alter_table, AddColumn, ColumnName, \ format_column_name, ColumnNullable, \ - format_server_default,ColumnDefault, format_type, ColumnType + format_server_default, ColumnDefault, format_type, ColumnType + class OracleImpl(DefaultImpl): __dialect__ = 'oracle' @@ -14,8 +15,8 @@ class OracleImpl(DefaultImpl): def __init__(self, *arg, **kw): super(OracleImpl, self).__init__(*arg, **kw) self.batch_separator = self.context_opts.get( - "oracle_batch_separator", - self.batch_separator) + "oracle_batch_separator", + self.batch_separator) def _exec(self, construct, *args, **kw): super(OracleImpl, self)._exec(construct, *args, **kw) @@ -28,6 +29,7 @@ class OracleImpl(DefaultImpl): def emit_commit(self): self._exec("COMMIT") + @compiles(AddColumn, 'oracle') def visit_add_column(element, compiler, **kw): return "%s %s" % ( @@ -35,6 +37,7 @@ def visit_add_column(element, compiler, **kw): add_column(compiler, element.column, **kw), ) + @compiles(ColumnNullable, 'oracle') def visit_column_nullable(element, compiler, **kw): return "%s %s %s" % ( @@ -43,6 +46,7 @@ def visit_column_nullable(element, compiler, **kw): "NULL" if element.nullable else "NOT NULL" ) + @compiles(ColumnType, 'oracle') def visit_column_type(element, compiler, **kw): return "%s %s %s" % ( @@ -51,6 +55,7 @@ def visit_column_type(element, compiler, **kw): "%s" % format_type(compiler, element.type_) ) + @compiles(ColumnName, 'oracle') def visit_column_name(element, compiler, **kw): return "%s RENAME COLUMN %s TO %s" % ( @@ -59,19 +64,22 @@ def visit_column_name(element, compiler, **kw): format_column_name(compiler, element.newname) ) + @compiles(ColumnDefault, 'oracle') def visit_column_default(element, compiler, **kw): return "%s %s %s" % ( alter_table(compiler, element.table_name, element.schema), alter_column(compiler, element.column_name), "DEFAULT %s" % - format_server_default(compiler, element.default) + format_server_default(compiler, element.default) if element.default is not None else "DEFAULT NULL" ) + def alter_column(compiler, name): return 'MODIFY %s' % format_column_name(compiler, name) + def add_column(compiler, column, **kw): return "ADD %s" % compiler.get_column_specification(column, **kw) diff --git a/alembic/ddl/postgresql.py b/alembic/ddl/postgresql.py index 27f31b0..eab1f4d 100644 --- a/alembic/ddl/postgresql.py +++ b/alembic/ddl/postgresql.py @@ -5,18 +5,19 @@ from .. import compat from .base import compiles, alter_table, format_table_name, RenameTable from .impl import DefaultImpl + class PostgresqlImpl(DefaultImpl): __dialect__ = 'postgresql' transactional_ddl = True def compare_server_default(self, inspector_column, - metadata_column, - rendered_metadata_default, - rendered_inspector_default): + metadata_column, + rendered_metadata_default, + rendered_inspector_default): # don't do defaults for SERIAL columns if metadata_column.primary_key and \ - metadata_column is metadata_column.table._autoincrement_column: + metadata_column is metadata_column.table._autoincrement_column: return False conn_col_default = rendered_inspector_default @@ -26,7 +27,7 @@ class PostgresqlImpl(DefaultImpl): if metadata_column.server_default is not None and \ isinstance(metadata_column.server_default.arg, - compat.string_types) and \ + compat.string_types) and \ not re.match(r"^'.+'$", rendered_metadata_default): rendered_metadata_default = "'%s'" % rendered_metadata_default diff --git a/alembic/ddl/sqlite.py b/alembic/ddl/sqlite.py index 85c829e..1a00be1 100644 --- a/alembic/ddl/sqlite.py +++ b/alembic/ddl/sqlite.py @@ -6,6 +6,7 @@ import re #from .base import AddColumn, alter_table #from sqlalchemy.schema import AddConstraint + class SQLiteImpl(DefaultImpl): __dialect__ = 'sqlite' @@ -19,21 +20,20 @@ class SQLiteImpl(DefaultImpl): # auto-gen constraint and an explicit one if const._create_rule is None: raise NotImplementedError( - "No support for ALTER of constraints in SQLite dialect") + "No support for ALTER of constraints in SQLite dialect") elif const._create_rule(self): util.warn("Skipping unsupported ALTER for " - "creation of implicit constraint") - + "creation of implicit constraint") def drop_constraint(self, const): if const._create_rule is None: raise NotImplementedError( - "No support for ALTER of constraints in SQLite dialect") + "No support for ALTER of constraints in SQLite dialect") def compare_server_default(self, inspector_column, - metadata_column, - rendered_metadata_default, - rendered_inspector_default): + metadata_column, + rendered_metadata_default, + rendered_inspector_default): rendered_metadata_default = re.sub(r"^'|'$", "", rendered_metadata_default) return rendered_inspector_default != repr(rendered_metadata_default) @@ -46,9 +46,9 @@ class SQLiteImpl(DefaultImpl): return tuple(sorted(uq.columns.keys())) conn_unique_sigs = set( - uq_sig(uq) - for uq in conn_unique_constraints - ) + uq_sig(uq) + for uq in conn_unique_constraints + ) for idx in list(metadata_unique_constraints): # SQLite backend can't report on unnamed UNIQUE constraints, @@ -65,18 +65,18 @@ class SQLiteImpl(DefaultImpl): conn_uniques.remove(idx) #@compiles(AddColumn, 'sqlite') -#def visit_add_column(element, compiler, **kw): +# def visit_add_column(element, compiler, **kw): # return "%s %s" % ( # alter_table(compiler, element.table_name, element.schema), # add_column(compiler, element.column, **kw) # ) -#def add_column(compiler, column, **kw): +# def add_column(compiler, column, **kw): # text = "ADD COLUMN %s" % compiler.get_column_specification(column, **kw) -# # need to modify SQLAlchemy so that the CHECK associated with a Boolean -# # or Enum gets placed as part of the column constraints, not the Table -# # see ticket 98 +# need to modify SQLAlchemy so that the CHECK associated with a Boolean +# or Enum gets placed as part of the column constraints, not the Table +# see ticket 98 # for const in column.constraints: # text += compiler.process(AddConstraint(const)) # return text diff --git a/alembic/environment.py b/alembic/environment.py index c3e7a38..405e2f2 100644 --- a/alembic/environment.py +++ b/alembic/environment.py @@ -2,7 +2,9 @@ from .operations import Operations from .migration import MigrationContext from . import util + class EnvironmentContext(object): + """Represent the state made available to an ``env.py`` script. :class:`.EnvironmentContext` is normally instantiated @@ -156,13 +158,13 @@ class EnvironmentContext(object): """ if self._migration_context is not None: return self.script._as_rev_number( - self.get_context()._start_from_rev) + self.get_context()._start_from_rev) elif 'starting_rev' in self.context_opts: return self.script._as_rev_number( - self.context_opts['starting_rev']) + self.context_opts['starting_rev']) else: raise util.CommandError( - "No starting revision argument is available.") + "No starting revision argument is available.") def get_revision_argument(self): """Get the 'destination' revision argument. @@ -179,7 +181,7 @@ class EnvironmentContext(object): """ return self.script._as_rev_number( - self.context_opts['destination_rev']) + self.context_opts['destination_rev']) def get_tag_argument(self): """Return the value passed for the ``--tag`` argument, if any. @@ -247,34 +249,34 @@ class EnvironmentContext(object): value = [] if as_dictionary: value = dict( - arg.split('=', 1) for arg in value - ) + arg.split('=', 1) for arg in value + ) return value def configure(self, - connection=None, - url=None, - dialect_name=None, - transactional_ddl=None, - transaction_per_migration=False, - output_buffer=None, - starting_rev=None, - tag=None, - template_args=None, - target_metadata=None, - include_symbol=None, - include_object=None, - include_schemas=False, - compare_type=False, - compare_server_default=False, - render_item=None, - upgrade_token="upgrades", - downgrade_token="downgrades", - alembic_module_prefix="op.", - sqlalchemy_module_prefix="sa.", - user_module_prefix=None, - **kw - ): + connection=None, + url=None, + dialect_name=None, + transactional_ddl=None, + transaction_per_migration=False, + output_buffer=None, + starting_rev=None, + tag=None, + template_args=None, + target_metadata=None, + include_symbol=None, + include_object=None, + include_schemas=False, + compare_type=False, + compare_server_default=False, + render_item=None, + upgrade_token="upgrades", + downgrade_token="downgrades", + alembic_module_prefix="op.", + sqlalchemy_module_prefix="sa.", + user_module_prefix=None, + **kw + ): """Configure a :class:`.MigrationContext` within this :class:`.EnvironmentContext` which will provide database connectivity and other configuration to a series of @@ -701,7 +703,7 @@ class EnvironmentContext(object): """ self.get_context().execute(sql, - execution_options=execution_options) + execution_options=execution_options) def static_output(self, text): """Emit text directly to the "offline" SQL stream. @@ -714,7 +716,6 @@ class EnvironmentContext(object): """ self.get_context().impl.static_output(text) - def begin_transaction(self): """Return a context manager that will enclose an operation within a "transaction", @@ -761,7 +762,6 @@ class EnvironmentContext(object): return self.get_context().begin_transaction() - def get_context(self): """Return the current :class:`.MigrationContext` object. @@ -789,4 +789,3 @@ class EnvironmentContext(object): def get_impl(self): return self.get_context().impl - diff --git a/alembic/migration.py b/alembic/migration.py index dadf49a..0c91fd1 100644 --- a/alembic/migration.py +++ b/alembic/migration.py @@ -13,7 +13,9 @@ from . import ddl, util log = logging.getLogger(__name__) + class MigrationContext(object): + """Represent the database state made available to a migration script. @@ -58,6 +60,7 @@ class MigrationContext(object): op.alter_column("mytable", "somecolumn", nullable=True) """ + def __init__(self, dialect, connection, opts, environment_context=None): self.environment_context = environment_context self.opts = opts @@ -68,7 +71,7 @@ class MigrationContext(object): transactional_ddl = opts.get("transactional_ddl") self._transaction_per_migration = opts.get( - "transaction_per_migration", False) + "transaction_per_migration", False) if as_sql: self.connection = self._stdout_connection(connection) @@ -88,8 +91,8 @@ class MigrationContext(object): self._user_compare_type = opts.get('compare_type', False) self._user_compare_server_default = opts.get( - 'compare_server_default', - False) + 'compare_server_default', + False) version_table = opts.get('version_table', 'alembic_version') version_table_schema = opts.get('version_table_schema', None) self._version = Table( @@ -99,26 +102,26 @@ class MigrationContext(object): self._start_from_rev = opts.get("starting_rev") self.impl = ddl.DefaultImpl.get_by_dialect(dialect)( - dialect, self.connection, self.as_sql, - transactional_ddl, - self.output_buffer, - opts - ) + dialect, self.connection, self.as_sql, + transactional_ddl, + self.output_buffer, + opts + ) log.info("Context impl %s.", self.impl.__class__.__name__) if self.as_sql: log.info("Generating static SQL") log.info("Will assume %s DDL.", - "transactional" if self.impl.transactional_ddl - else "non-transactional") + "transactional" if self.impl.transactional_ddl + else "non-transactional") @classmethod def configure(cls, - connection=None, - url=None, - dialect_name=None, - environment_context=None, - opts=None, - ): + connection=None, + url=None, + dialect_name=None, + environment_context=None, + opts=None, + ): """Create a new :class:`.MigrationContext`. This is a factory method usually called @@ -155,7 +158,6 @@ class MigrationContext(object): return MigrationContext(dialect, connection, opts, environment_context) - def begin_transaction(self, _per_migration=False): transaction_now = _per_migration == self._transaction_per_migration @@ -209,12 +211,12 @@ class MigrationContext(object): self.impl._exec(self._version.delete()) elif old is None: self.impl._exec(self._version.insert(). - values(version_num=literal_column("'%s'" % new)) - ) + values(version_num=literal_column("'%s'" % new)) + ) else: self.impl._exec(self._version.update(). - values(version_num=literal_column("'%s'" % new)) - ) + values(version_num=literal_column("'%s'" % new)) + ) def run_migrations(self, **kw): """Run the migration scripts established for this :class:`.MigrationContext`, @@ -239,12 +241,12 @@ class MigrationContext(object): """ current_rev = rev = False stamp_per_migration = not self.impl.transactional_ddl or \ - self._transaction_per_migration + self._transaction_per_migration self.impl.start_migrations() for change, prev_rev, rev, doc in self._migrations_fn( - self.get_current_revision(), - self): + self.get_current_revision(), + self): with self.begin_transaction(_per_migration=True): if current_rev is False: current_rev = prev_rev @@ -252,14 +254,14 @@ class MigrationContext(object): self._version.create(self.connection) if doc: log.info("Running %s %s -> %s, %s", change.__name__, prev_rev, - rev, doc) + rev, doc) else: log.info("Running %s %s -> %s", change.__name__, prev_rev, rev) if self.as_sql: self.impl.static_output( - "-- Running %s %s -> %s" % - (change.__name__, prev_rev, rev) - ) + "-- Running %s %s -> %s" % + (change.__name__, prev_rev, rev) + ) change(**kw) if stamp_per_migration: self._update_current_rev(prev_rev, rev) @@ -288,7 +290,7 @@ class MigrationContext(object): self.impl._exec(construct) return create_engine("%s://" % self.dialect.name, - strategy="mock", executor=dump) + strategy="mock", executor=dump) @property def bind(self): @@ -338,32 +340,31 @@ class MigrationContext(object): return user_value return self.impl.compare_type( - inspector_column, - metadata_column) + inspector_column, + metadata_column) def _compare_server_default(self, inspector_column, - metadata_column, - rendered_metadata_default, - rendered_column_default): + metadata_column, + rendered_metadata_default, + rendered_column_default): if self._user_compare_server_default is False: return False if callable(self._user_compare_server_default): user_value = self._user_compare_server_default( - self, - inspector_column, - metadata_column, - rendered_column_default, - metadata_column.server_default, - rendered_metadata_default + self, + inspector_column, + metadata_column, + rendered_column_default, + metadata_column.server_default, + rendered_metadata_default ) if user_value is not None: return user_value return self.impl.compare_server_default( - inspector_column, - metadata_column, - rendered_metadata_default, - rendered_column_default) - + inspector_column, + metadata_column, + rendered_metadata_default, + rendered_column_default) diff --git a/alembic/operations.py b/alembic/operations.py index d028688..a1f3dee 100644 --- a/alembic/operations.py +++ b/alembic/operations.py @@ -14,7 +14,9 @@ try: except: conv = None + class Operations(object): + """Define high level migration operations. Each operation corresponds to some schema migration operation, @@ -39,6 +41,7 @@ class Operations(object): op.alter_column("t", "c", nullable=True) """ + def __init__(self, migration_context): """Construct a new :class:`.Operations` @@ -58,57 +61,56 @@ class Operations(object): yield op _remove_proxy() - def _primary_key_constraint(self, name, table_name, cols, schema=None): m = self._metadata() columns = [sa_schema.Column(n, NULLTYPE) for n in cols] t1 = sa_schema.Table(table_name, m, - *columns, - schema=schema) + *columns, + schema=schema) p = sa_schema.PrimaryKeyConstraint(*columns, name=name) t1.append_constraint(p) return p def _foreign_key_constraint(self, name, source, referent, - local_cols, remote_cols, - onupdate=None, ondelete=None, - deferrable=None, source_schema=None, - referent_schema=None, initially=None, - match=None, **dialect_kw): + local_cols, remote_cols, + onupdate=None, ondelete=None, + deferrable=None, source_schema=None, + referent_schema=None, initially=None, + match=None, **dialect_kw): m = self._metadata() if source == referent: t1_cols = local_cols + remote_cols else: t1_cols = local_cols sa_schema.Table(referent, m, - *[sa_schema.Column(n, NULLTYPE) for n in remote_cols], - schema=referent_schema) + *[sa_schema.Column(n, NULLTYPE) for n in remote_cols], + schema=referent_schema) t1 = sa_schema.Table(source, m, - *[sa_schema.Column(n, NULLTYPE) for n in t1_cols], - schema=source_schema) + *[sa_schema.Column(n, NULLTYPE) for n in t1_cols], + schema=source_schema) tname = "%s.%s" % (referent_schema, referent) if referent_schema \ else referent f = sa_schema.ForeignKeyConstraint(local_cols, - ["%s.%s" % (tname, n) + ["%s.%s" % (tname, n) for n in remote_cols], - name=name, - onupdate=onupdate, - ondelete=ondelete, - deferrable=deferrable, - initially=initially, - match=match, - **dialect_kw - ) + name=name, + onupdate=onupdate, + ondelete=ondelete, + deferrable=deferrable, + initially=initially, + match=match, + **dialect_kw + ) t1.append_constraint(f) return f def _unique_constraint(self, name, source, local_cols, schema=None, **kw): t = sa_schema.Table(source, self._metadata(), - *[sa_schema.Column(n, NULLTYPE) for n in local_cols], - schema=schema) + *[sa_schema.Column(n, NULLTYPE) for n in local_cols], + schema=schema) kw['name'] = name uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw) # TODO: need event tests to ensure the event @@ -118,7 +120,7 @@ class Operations(object): def _check_constraint(self, name, source, condition, schema=None, **kw): t = sa_schema.Table(source, self._metadata(), - sa_schema.Column('x', Integer), schema=schema) + sa_schema.Column('x', Integer), schema=schema) ck = sa_schema.CheckConstraint(condition, name=name, **kw) t.append_constraint(ck) return ck @@ -201,17 +203,17 @@ class Operations(object): @util._with_legacy_names([('name', 'new_column_name')]) def alter_column(self, table_name, column_name, - nullable=None, - server_default=False, - new_column_name=None, - type_=None, - autoincrement=None, - existing_type=None, - existing_server_default=False, - existing_nullable=None, - existing_autoincrement=None, - schema=None - ): + nullable=None, + server_default=False, + new_column_name=None, + type_=None, + autoincrement=None, + existing_type=None, + existing_server_default=False, + existing_nullable=None, + existing_autoincrement=None, + schema=None + ): """Issue an "alter column" instruction using the current migration context. @@ -291,9 +293,10 @@ class Operations(object): """ compiler = self.impl.dialect.statement_compiler( - self.impl.dialect, - None - ) + self.impl.dialect, + None + ) + def _count_constraint(constraint): return not isinstance(constraint, sa_schema.PrimaryKeyConstraint) and \ (not constraint._create_rule or @@ -301,31 +304,31 @@ class Operations(object): if existing_type and type_: t = self._table(table_name, - sa_schema.Column(column_name, existing_type), - schema=schema - ) + sa_schema.Column(column_name, existing_type), + schema=schema + ) for constraint in t.constraints: if _count_constraint(constraint): self.impl.drop_constraint(constraint) self.impl.alter_column(table_name, column_name, - nullable=nullable, - server_default=server_default, - name=new_column_name, - type_=type_, - schema=schema, - autoincrement=autoincrement, - existing_type=existing_type, - existing_server_default=existing_server_default, - existing_nullable=existing_nullable, - existing_autoincrement=existing_autoincrement - ) + nullable=nullable, + server_default=server_default, + name=new_column_name, + type_=type_, + schema=schema, + autoincrement=autoincrement, + existing_type=existing_type, + existing_server_default=existing_server_default, + existing_nullable=existing_nullable, + existing_autoincrement=existing_autoincrement + ) if type_: t = self._table(table_name, - sa_schema.Column(column_name, type_), - schema=schema - ) + sa_schema.Column(column_name, type_), + schema=schema + ) for constraint in t.constraints: if _count_constraint(constraint): self.impl.add_constraint(constraint) @@ -374,7 +377,7 @@ class Operations(object): return conv(name) else: raise NotImplementedError( - "op.f() feature requires SQLAlchemy 0.9.4 or greater.") + "op.f() feature requires SQLAlchemy 0.9.4 or greater.") def add_column(self, table_name, column, schema=None): """Issue an "add column" instruction using the current @@ -481,7 +484,6 @@ class Operations(object): **kw ) - def create_primary_key(self, name, table_name, cols, schema=None): """Issue a "create primary key" instruction using the current migration context. @@ -518,10 +520,9 @@ class Operations(object): """ self.impl.add_constraint( - self._primary_key_constraint(name, table_name, cols, - schema) - ) - + self._primary_key_constraint(name, table_name, cols, + schema) + ) def create_foreign_key(self, name, source, referent, local_cols, remote_cols, onupdate=None, ondelete=None, @@ -573,13 +574,13 @@ class Operations(object): """ self.impl.add_constraint( - self._foreign_key_constraint(name, source, referent, - local_cols, remote_cols, - onupdate=onupdate, ondelete=ondelete, - deferrable=deferrable, source_schema=source_schema, - referent_schema=referent_schema, - initially=initially, match=match, **dialect_kw) - ) + self._foreign_key_constraint(name, source, referent, + local_cols, remote_cols, + onupdate=onupdate, ondelete=ondelete, + deferrable=deferrable, source_schema=source_schema, + referent_schema=referent_schema, + initially=initially, match=match, **dialect_kw) + ) def create_unique_constraint(self, name, source, local_cols, schema=None, **kw): @@ -621,9 +622,9 @@ class Operations(object): """ self.impl.add_constraint( - self._unique_constraint(name, source, local_cols, - schema=schema, **kw) - ) + self._unique_constraint(name, source, local_cols, + schema=schema, **kw) + ) def create_check_constraint(self, name, source, condition, schema=None, **kw): @@ -841,7 +842,7 @@ class Operations(object): t = self._table(table_name, schema=schema) types = { 'foreignkey': lambda name: sa_schema.ForeignKeyConstraint( - [], [], name=name), + [], [], name=name), 'primary': sa_schema.PrimaryKeyConstraint, 'unique': sa_schema.UniqueConstraint, 'check': lambda name: sa_schema.CheckConstraint("", name=name), @@ -851,7 +852,7 @@ class Operations(object): const = types[type_] except KeyError: raise TypeError("'type' can be one of %s" % - ", ".join(sorted(repr(x) for x in types))) + ", ".join(sorted(repr(x) for x in types))) const = const(name=name) t.append_constraint(const) @@ -1038,7 +1039,7 @@ class Operations(object): :meth:`sqlalchemy.engine.Connection.execution_options`. """ self.migration_context.impl.execute(sql, - execution_options=execution_options) + execution_options=execution_options) def get_bind(self): """Return the current 'bind'. @@ -1051,4 +1052,3 @@ class Operations(object): """ return self.migration_context.impl.bind - diff --git a/alembic/script.py b/alembic/script.py index ed44f71..a97fc9c 100644 --- a/alembic/script.py +++ b/alembic/script.py @@ -12,7 +12,9 @@ _slug_re = re.compile(r'\w+') _default_file_template = "%(rev)s_%(slug)s" _relative_destination = re.compile(r'(?:\+|-)\d+') + class ScriptDirectory(object): + """Provides operations upon an Alembic script directory. This object is useful to get information as to current revisions, @@ -31,9 +33,10 @@ class ScriptDirectory(object): """ + def __init__(self, dir, file_template=_default_file_template, - truncate_slug_length=40, - sourceless=False): + truncate_slug_length=40, + sourceless=False): self.dir = dir self.versions = os.path.join(self.dir, 'versions') self.file_template = file_template @@ -42,8 +45,8 @@ class ScriptDirectory(object): if not os.access(dir, os.F_OK): raise util.CommandError("Path doesn't exist: %r. Please use " - "the 'init' command to create a new " - "scripts folder." % dir) + "the 'init' command to create a new " + "scripts folder." % dir) @classmethod def from_config(cls, config): @@ -62,13 +65,13 @@ class ScriptDirectory(object): if truncate_slug_length is not None: truncate_slug_length = int(truncate_slug_length) return ScriptDirectory( - util.coerce_resource_to_filename(script_location), - file_template=config.get_main_option( - 'file_template', - _default_file_template), - truncate_slug_length=truncate_slug_length, - sourceless=config.get_main_option("sourceless") == "true" - ) + util.coerce_resource_to_filename(script_location), + file_template=config.get_main_option( + 'file_template', + _default_file_template), + truncate_slug_length=truncate_slug_length, + sourceless=config.get_main_option("sourceless") == "true" + ) def walk_revisions(self, base="base", head="head"): """Iterate through all revisions. @@ -108,11 +111,11 @@ class ScriptDirectory(object): raise util.CommandError("No such revision '%s'" % id_) elif len(revs) > 1: raise util.CommandError( - "Multiple revisions start " - "with '%s', %s..." % ( - id_, - ", ".join("'%s'" % r for r in revs[0:3]) - )) + "Multiple revisions start " + "with '%s', %s..." % ( + id_, + ", ".join("'%s'" % r for r in revs[0:3]) + )) else: return self._revision_map[revs[0]] @@ -148,7 +151,7 @@ class ScriptDirectory(object): revs = revs[-relative:] if len(revs) != abs(relative): raise util.CommandError("Relative revision %s didn't " - "produce %d migrations" % (upper, abs(relative))) + "produce %d migrations" % (upper, abs(relative))) return iter(revs) elif lower is not None and _relative_destination.match(lower): relative = int(lower) @@ -156,7 +159,7 @@ class ScriptDirectory(object): revs = revs[0:-relative] if len(revs) != abs(relative): raise util.CommandError("Relative revision %s didn't " - "produce %d migrations" % (lower, abs(relative))) + "produce %d migrations" % (lower, abs(relative))) return iter(revs) else: return self._iterate_revisions(upper, lower) @@ -165,12 +168,12 @@ class ScriptDirectory(object): lower = self.get_revision(lower) upper = self.get_revision(upper) orig = lower.revision if lower else 'base', \ - upper.revision if upper else 'base' + upper.revision if upper else 'base' script = upper while script != lower: if script is None and lower is not None: raise util.CommandError( - "Revision %s is not an ancestor of %s" % orig) + "Revision %s is not an ancestor of %s" % orig) yield script downrev = script.down_revision script = self._revision_map[downrev] @@ -181,7 +184,7 @@ class ScriptDirectory(object): (script.module.upgrade, script.down_revision, script.revision, script.doc) for script in reversed(list(revs)) - ] + ] def _downgrade_revs(self, destination, current_rev): revs = self.iterate_revisions(current_rev, destination) @@ -189,7 +192,7 @@ class ScriptDirectory(object): (script.module.downgrade, script.revision, script.down_revision, script.doc) for script in revs - ] + ] def run_env(self): """Run the script environment. @@ -216,14 +219,14 @@ class ScriptDirectory(object): continue if script.revision in map_: util.warn("Revision %s is present more than once" % - script.revision) + script.revision) map_[script.revision] = script for rev in map_.values(): if rev.down_revision is None: continue if rev.down_revision not in map_: util.warn("Revision %s referenced from %s is not present" - % (rev.down_revision, rev)) + % (rev.down_revision, rev)) rev.down_revision = None else: map_[rev.down_revision].add_nextrev(rev.revision) @@ -260,10 +263,10 @@ class ScriptDirectory(object): current_heads = self.get_heads() if len(current_heads) > 1: raise util.CommandError('Only a single head is supported. The ' - 'script directory has multiple heads (due to branching), which ' - 'must be resolved by manually editing the revision files to ' - 'form a linear sequence. Run `alembic branches` to see the ' - 'divergence(s).') + 'script directory has multiple heads (due to branching), which ' + 'must be resolved by manually editing the revision files to ' + 'form a linear sequence. Run `alembic branches` to see the ' + 'divergence(s).') if current_heads: return current_heads[0] @@ -303,18 +306,18 @@ class ScriptDirectory(object): """ for script in self._revision_map.values(): if script and script.down_revision is None \ - and script.revision in self._revision_map: + and script.revision in self._revision_map: return script.revision else: return None def _generate_template(self, src, dest, **kw): util.status("Generating %s" % os.path.abspath(dest), - util.template_to_file, - src, - dest, - **kw - ) + util.template_to_file, + src, + dest, + **kw + ) def _copy_file(self, src, dest): util.status("Generating %s" % os.path.abspath(dest), @@ -357,13 +360,14 @@ class ScriptDirectory(object): self._revision_map[script.revision] = script if script.down_revision: self._revision_map[script.down_revision].\ - add_nextrev(script.revision) + add_nextrev(script.revision) return script else: return None class Script(object): + """Represent a single revision file in a ``versions/`` directory. The :class:`.Script` instance is returned by methods @@ -455,11 +459,11 @@ class Script(object): def __str__(self): return "%s -> %s%s%s, %s" % ( - self.down_revision, - self.revision, - " (head)" if self.is_head else "", - " (branchpoint)" if self.is_branch_point else "", - self.doc) + self.down_revision, + self.revision, + " (head)" if self.is_head else "", + " (branchpoint)" if self.is_branch_point else "", + self.doc) @classmethod def _from_path(cls, scriptdir, path): @@ -502,11 +506,11 @@ class Script(object): m = _legacy_rev.match(filename) if not m: raise util.CommandError( - "Could not determine revision id from filename %s. " - "Be sure the 'revision' variable is " - "declared inside the script (please see 'Upgrading " - "from Alembic 0.1 to 0.2' in the documentation)." - % filename) + "Could not determine revision id from filename %s. " + "Be sure the 'revision' variable is " + "declared inside the script (please see 'Upgrading " + "from Alembic 0.1 to 0.2' in the documentation)." + % filename) else: revision = m.group(1) else: diff --git a/alembic/templates/generic/env.py b/alembic/templates/generic/env.py index 712b616..fccd445 100644 --- a/alembic/templates/generic/env.py +++ b/alembic/templates/generic/env.py @@ -22,6 +22,7 @@ target_metadata = None # my_important_option = config.get_main_option("my_important_option") # ... etc. + def run_migrations_offline(): """Run migrations in 'offline' mode. @@ -40,6 +41,7 @@ def run_migrations_offline(): with context.begin_transaction(): context.run_migrations() + def run_migrations_online(): """Run migrations in 'online' mode. @@ -48,15 +50,15 @@ def run_migrations_online(): """ engine = engine_from_config( - config.get_section(config.config_ini_section), - prefix='sqlalchemy.', - poolclass=pool.NullPool) + config.get_section(config.config_ini_section), + prefix='sqlalchemy.', + poolclass=pool.NullPool) connection = engine.connect() context.configure( - connection=connection, - target_metadata=target_metadata - ) + connection=connection, + target_metadata=target_metadata + ) try: with context.begin_transaction(): @@ -68,4 +70,3 @@ if context.is_offline_mode(): run_migrations_offline() else: run_migrations_online() - diff --git a/alembic/templates/multidb/env.py b/alembic/templates/multidb/env.py index e3511de..ab37199 100644 --- a/alembic/templates/multidb/env.py +++ b/alembic/templates/multidb/env.py @@ -39,6 +39,7 @@ target_metadata = {} # my_important_option = config.get_main_option("my_important_option") # ... etc. + def run_migrations_offline(): """Run migrations in 'offline' mode. @@ -58,7 +59,7 @@ def run_migrations_offline(): for name in re.split(r',\s*', db_names): engines[name] = rec = {} rec['url'] = context.config.get_section_option(name, - "sqlalchemy.url") + "sqlalchemy.url") for name, rec in engines.items(): logger.info("Migrating database %s" % name) @@ -66,10 +67,11 @@ def run_migrations_offline(): logger.info("Writing output to %s" % file_) with open(file_, 'w') as buffer: context.configure(url=rec['url'], output_buffer=buffer, - target_metadata=target_metadata.get(name)) + target_metadata=target_metadata.get(name)) with context.begin_transaction(): context.run_migrations(engine_name=name) + def run_migrations_online(): """Run migrations in 'online' mode. @@ -85,9 +87,9 @@ def run_migrations_online(): for name in re.split(r',\s*', db_names): engines[name] = rec = {} rec['engine'] = engine_from_config( - context.config.get_section(name), - prefix='sqlalchemy.', - poolclass=pool.NullPool) + context.config.get_section(name), + prefix='sqlalchemy.', + poolclass=pool.NullPool) for name, rec in engines.items(): engine = rec['engine'] @@ -102,11 +104,11 @@ def run_migrations_online(): for name, rec in engines.items(): logger.info("Migrating database %s" % name) context.configure( - connection=rec['connection'], - upgrade_token="%s_upgrades" % name, - downgrade_token="%s_downgrades" % name, - target_metadata=target_metadata.get(name) - ) + connection=rec['connection'], + upgrade_token="%s_upgrades" % name, + downgrade_token="%s_downgrades" % name, + target_metadata=target_metadata.get(name) + ) context.run_migrations(engine_name=name) if USE_TWOPHASE: diff --git a/alembic/templates/pylons/env.py b/alembic/templates/pylons/env.py index 36c3fca..3329428 100644 --- a/alembic/templates/pylons/env.py +++ b/alembic/templates/pylons/env.py @@ -46,7 +46,7 @@ def run_migrations_offline(): """ context.configure( - url=meta.engine.url, target_metadata=target_metadata) + url=meta.engine.url, target_metadata=target_metadata) with context.begin_transaction(): context.run_migrations() @@ -70,9 +70,9 @@ def run_migrations_online(): ) context.configure( - connection=connection, - target_metadata=target_metadata - ) + connection=connection, + target_metadata=target_metadata + ) try: with context.begin_transaction(): diff --git a/alembic/util.py b/alembic/util.py index 8c02d57..e0d62eb 100644 --- a/alembic/util.py +++ b/alembic/util.py @@ -12,9 +12,11 @@ from sqlalchemy import __version__ from .compat import callable, exec_, load_module_py, load_module_pyc, binary_type + class CommandError(Exception): pass + def _safe_int(value): try: return int(value) @@ -28,7 +30,7 @@ sqla_092 = _vers >= (0, 9, 2) sqla_094 = _vers >= (0, 9, 4) if not sqla_07: raise CommandError( - "SQLAlchemy 0.7.3 or greater is required. ") + "SQLAlchemy 0.7.3 or greater is required. ") from sqlalchemy.util import format_argspec_plus, update_wrapper from sqlalchemy.util.compat import inspect_getfullargspec @@ -41,7 +43,7 @@ try: import termios import struct ioctl = fcntl.ioctl(0, termios.TIOCGWINSZ, - struct.pack('HHHH', 0, 0, 0, 0)) + struct.pack('HHHH', 0, 0, 0, 0)) _h, TERMWIDTH, _hp, _wp = struct.unpack('HHHH', ioctl) if TERMWIDTH <= 0: # can occur if running in emacs pseudo-tty TERMWIDTH = None @@ -55,6 +57,7 @@ def template_to_file(template_file, dest, **kw): Template(filename=template_file).render(**kw) ) + def create_module_class_proxy(cls, globals_, locals_): """Create module level proxy functions for the methods on a given class. @@ -97,18 +100,18 @@ def create_module_class_proxy(cls, globals_, locals_): defaulted_vals = () apply_kw = inspect.formatargspec( - name_args, spec[1], spec[2], - defaulted_vals, - formatvalue=lambda x: '=' + x) + name_args, spec[1], spec[2], + defaulted_vals, + formatvalue=lambda x: '=' + x) def _name_error(name): raise NameError( - "Can't invoke function '%s', as the proxy object has "\ - "not yet been " - "established for the Alembic '%s' class. " - "Try placing this code inside a callable." % ( - name, cls.__name__ - )) + "Can't invoke function '%s', as the proxy object has " + "not yet been " + "established for the Alembic '%s' class. " + "Try placing this code inside a callable." % ( + name, cls.__name__ + )) globals_['_name_error'] = _name_error func_text = textwrap.dedent("""\ @@ -137,6 +140,7 @@ def create_module_class_proxy(cls, globals_, locals_): else: attr_names.add(methname) + def write_outstream(stream, *text): encoding = getattr(stream, 'encoding', 'ascii') or 'ascii' for t in text: @@ -151,6 +155,7 @@ def write_outstream(stream, *text): # as the exception is "ignored" (noisily) in TextIOWrapper. break + def coerce_resource_to_filename(fname): """Interpret a filename as either a filesystem location or as a package resource. @@ -163,6 +168,7 @@ def coerce_resource_to_filename(fname): fname = pkg_resources.resource_filename(*fname.split(':')) return fname + def status(_statmsg, fn, *arg, **kw): msg(_statmsg + " ...", False) try: @@ -173,24 +179,29 @@ def status(_statmsg, fn, *arg, **kw): write_outstream(sys.stdout, " FAILED\n") raise + def err(message): log.error(message) msg("FAILED: %s" % message) sys.exit(-1) + def obfuscate_url_pw(u): u = url.make_url(u) if u.password: u.password = 'XXXXX' return str(u) + def asbool(value): return value is not None and \ value.lower() == 'true' + def warn(msg): warnings.warn(msg) + def msg(msg, newline=True): if TERMWIDTH is None: write_outstream(sys.stdout, msg) @@ -204,6 +215,7 @@ def msg(msg, newline=True): write_outstream(sys.stdout, " ", line, "\n") write_outstream(sys.stdout, " ", lines[-1], ("\n" if newline else "")) + def load_python_file(dir_, filename): """Load a file from the given path as a Python module.""" @@ -223,6 +235,7 @@ def load_python_file(dir_, filename): del sys.modules[module_id] return module + def simple_pyc_file_from_path(path): """Given a python source path, return the so-called "sourceless" .pyc or .pyo path. @@ -238,6 +251,7 @@ def simple_pyc_file_from_path(path): else: return path + "c" # e.g. .pyc + def pyc_file_from_path(path): """Given a python source path, locate the .pyc. @@ -253,11 +267,14 @@ def pyc_file_from_path(path): else: return simple_pyc_file_from_path(path) + def rev_id(): val = int(uuid.uuid4()) % 100000000000000 return hex(val)[2:-1] + class memoized_property(object): + """A read-only @property that is only evaluated once.""" def __init__(self, fget, doc=None): @@ -278,7 +295,7 @@ class immutabledict(dict): raise TypeError("%s object is immutable" % self.__class__.__name__) __delitem__ = __setitem__ = __setattr__ = \ - clear = pop = popitem = setdefault = \ + clear = pop = popitem = setdefault = \ update = _immutable def __new__(cls, *args): @@ -332,7 +349,7 @@ def _with_legacy_names(translations): return fn(*arg, **kw) code = 'lambda %(args)s: %(target)s(%(apply_kw)s)' % ( - metadata) + metadata) decorated = eval(code, {"target": go}) decorated.__defaults__ = getattr(fn, '__func__', fn).__defaults__ update_wrapper(decorated, fn) @@ -346,6 +363,3 @@ def _with_legacy_names(translations): return decorated return decorate - - - @@ -34,14 +34,14 @@ setup(name='alembic', description="A database migration tool for SQLAlchemy.", long_description=open(readme).read(), classifiers=[ - 'Development Status :: 4 - Beta', - 'Environment :: Console', - 'Intended Audience :: Developers', - 'Programming Language :: Python', - 'Programming Language :: Python :: 3', - 'Programming Language :: Python :: Implementation :: CPython', - 'Programming Language :: Python :: Implementation :: PyPy', - 'Topic :: Database :: Front-Ends', + 'Development Status :: 4 - Beta', + 'Environment :: Console', + 'Intended Audience :: Developers', + 'Programming Language :: Python', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: Implementation :: CPython', + 'Programming Language :: Python :: Implementation :: PyPy', + 'Topic :: Database :: Front-Ends', ], keywords='SQLAlchemy migrations', author='Mike Bayer', @@ -50,11 +50,11 @@ setup(name='alembic', license='MIT', packages=find_packages('.', exclude=['examples*', 'test*']), include_package_data=True, - tests_require = ['nose >= 0.11', 'mock'], - test_suite = "nose.collector", + tests_require=['nose >= 0.11', 'mock'], + test_suite="nose.collector", zip_safe=False, install_requires=requires, - entry_points = { - 'console_scripts': [ 'alembic = alembic.config:main' ], + entry_points={ + 'console_scripts': ['alembic = alembic.config:main'], } -) + ) diff --git a/tests/__init__.py b/tests/__init__.py index ba8c0eb..9b5944f 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -37,8 +37,8 @@ else: import mock except ImportError: raise ImportError( - "Alembic's test suite requires the " - "'mock' library as of 0.6.1.") + "Alembic's test suite requires the " + "'mock' library as of 0.6.1.") def sqlite_db(): @@ -48,14 +48,18 @@ def sqlite_db(): dir_ = os.path.join(staging_directory, 'scripts') return create_engine('sqlite:///%s/foo.db' % dir_) + def capture_db(): buf = [] + def dump(sql, *multiparams, **params): buf.append(str(sql.compile(dialect=engine.dialect))) engine = create_engine("postgresql://", strategy="mock", executor=dump) return engine, buf _engs = {} + + def db_for_dialect(name): if name in _engs: return _engs[name] @@ -82,18 +86,21 @@ def requires_08(fn, *arg, **kw): raise SkipTest("SQLAlchemy 0.8.0b2 or greater required") return fn(*arg, **kw) + @decorator def requires_09(fn, *arg, **kw): if not util.sqla_09: raise SkipTest("SQLAlchemy 0.9 or greater required") return fn(*arg, **kw) + @decorator def requires_092(fn, *arg, **kw): if not util.sqla_092: raise SkipTest("SQLAlchemy 0.9.2 or greater required") return fn(*arg, **kw) + @decorator def requires_094(fn, *arg, **kw): if not util.sqla_094: @@ -101,6 +108,8 @@ def requires_094(fn, *arg, **kw): return fn(*arg, **kw) _dialects = {} + + def _get_dialect(name): if name is None or name == 'default': return default.DefaultDialect() @@ -114,14 +123,16 @@ def _get_dialect(name): d.implicit_returning = True return d + def assert_compiled(element, assert_string, dialect=None): dialect = _get_dialect(dialect) eq_( - text_type(element.compile(dialect=dialect)).\ - replace("\n", "").replace("\t", ""), + text_type(element.compile(dialect=dialect)). + replace("\n", "").replace("\t", ""), assert_string.replace("\n", "").replace("\t", "") ) + @contextmanager def capture_context_buffer(**kw): if kw.pop('bytes_io', False): @@ -130,10 +141,11 @@ def capture_context_buffer(**kw): buf = io.StringIO() kw.update({ - 'dialect_name': "sqlite", - 'output_buffer': buf + 'dialect_name': "sqlite", + 'output_buffer': buf }) conf = EnvironmentContext.configure + def configure(*arg, **opt): opt.update(**kw) return conf(*arg, **opt) @@ -141,6 +153,7 @@ def capture_context_buffer(**kw): with mock.patch.object(EnvironmentContext, "configure", configure): yield buf + def eq_ignore_whitespace(a, b, msg=None): a = re.sub(r'^\s+?|\n', "", a) a = re.sub(r' {2,}', " ", a) @@ -148,18 +161,22 @@ def eq_ignore_whitespace(a, b, msg=None): b = re.sub(r' {2,}', " ", b) assert a == b, msg or "%r != %r" % (a, b) + def eq_(a, b, msg=None): """Assert a == b, with repr messaging on failure.""" assert a == b, msg or "%r != %r" % (a, b) + def ne_(a, b, msg=None): """Assert a != b, with repr messaging on failure.""" assert a != b, msg or "%r == %r" % (a, b) + def is_(a, b, msg=None): """Assert a is b, with repr messaging on failure.""" assert a is b, msg or "%r is not %r" % (a, b) + def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): try: callable_(*args, **kwargs) @@ -168,9 +185,12 @@ def assert_raises_message(except_cls, msg, callable_, *args, **kwargs): assert re.search(msg, str(e)), "%r !~ %s" % (msg, e) print(text_type(e)) + def op_fixture(dialect='default', as_sql=False, naming_convention=None): impl = _impls[dialect] + class Impl(impl): + def __init__(self, dialect, as_sql): self.assertion = [] self.dialect = dialect @@ -179,6 +199,7 @@ def op_fixture(dialect='default', as_sql=False, naming_convention=None): # be more like a real connection # as tests get more involved self.connection = None + def _exec(self, construct, *args, **kw): if isinstance(construct, string_types): construct = text(construct) @@ -193,11 +214,12 @@ def op_fixture(dialect='default', as_sql=False, naming_convention=None): if naming_convention: if not util.sqla_092: raise SkipTest( - "naming_convention feature requires " - "sqla 0.9.2 or greater") + "naming_convention feature requires " + "sqla 0.9.2 or greater") opts['target_metadata'] = MetaData(naming_convention=naming_convention) class ctx(MigrationContext): + def __init__(self, dialect='default', as_sql=False): self.dialect = _get_dialect(dialect) self.impl = Impl(self.dialect, as_sql) @@ -222,12 +244,14 @@ def op_fixture(dialect='default', as_sql=False, naming_convention=None): alembic.op._proxy = Operations(context) return context + def script_file_fixture(txt): dir_ = os.path.join(staging_directory, 'scripts') path = os.path.join(dir_, "script.py.mako") with open(path, 'w') as f: f.write(txt) + def env_file_fixture(txt): dir_ = os.path.join(staging_directory, 'scripts') txt = """ @@ -244,6 +268,7 @@ config = context.config with open(path, 'w') as f: f.write(txt) + def _sqlite_testing_config(sourceless=False): dir_ = os.path.join(staging_directory, 'scripts') return _write_config_file(""" @@ -313,12 +338,14 @@ datefmt = %%H:%%M:%%S """ % (dir_, dialect, directives)) + def _write_config_file(text): cfg = _testing_config() with open(cfg.config_file_name, 'w') as f: f.write(text) return cfg + def _testing_config(): from alembic.config import Config if not os.access(staging_directory, os.F_OK): @@ -350,6 +377,7 @@ def staging_env(create=True, template="generic", sourceless=False): sc = script.ScriptDirectory.from_config(cfg) return sc + def clear_staging_env(): shutil.rmtree(staging_directory, True) @@ -370,13 +398,14 @@ def write_script(scriptdir, rev_id, content, encoding='ascii', sourceless=False) old = scriptdir._revision_map[script.revision] if old.down_revision != script.down_revision: raise Exception("Can't change down_revision " - "on a refresh operation.") + "on a refresh operation.") scriptdir._revision_map[script.revision] = script script.nextrev = old.nextrev if sourceless: make_sourceless(path) + def make_sourceless(path): # note that if -O is set, you'd see pyo files here, # the pyc util function looks at sys.flags.optimize to handle this @@ -391,6 +420,7 @@ def make_sourceless(path): shutil.copyfile(pyc_path, simple_pyc_path) os.unlink(path) + def three_rev_fixture(cfg): a = util.rev_id() b = util.rev_id() diff --git a/tests/test_autogen_indexes.py b/tests/test_autogen_indexes.py index 2f7a4a1..0885477 100644 --- a/tests/test_autogen_indexes.py +++ b/tests/test_autogen_indexes.py @@ -14,6 +14,7 @@ py3k = sys.version_info >= (3, ) from .test_autogenerate import AutogenFixtureTest + class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): reports_unique_constraints = True @@ -22,17 +23,17 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): m2 = MetaData() Table('user', m1, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False, index=True), - Column('a1', String(10), server_default="x") - ) + Column('id', Integer, primary_key=True), + Column('name', String(50), nullable=False, index=True), + Column('a1', String(10), server_default="x") + ) Table('user', m2, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', String(10), server_default="x"), - UniqueConstraint("name", name="uq_user_name") - ) + Column('id', Integer, primary_key=True), + Column('name', String(50), nullable=False), + Column('a1', String(10), server_default="x"), + UniqueConstraint("name", name="uq_user_name") + ) diffs = self._fixture(m1, m2) @@ -46,21 +47,20 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): eq_(diffs[0][0], "remove_index") eq_(diffs[0][1].name, "ix_user_name") - def test_add_unique_constraint(self): m1 = MetaData() m2 = MetaData() Table('address', m1, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - Column('qpr', String(10), index=True), - ) + Column('id', Integer, primary_key=True), + Column('email_address', String(100), nullable=False), + Column('qpr', String(10), index=True), + ) Table('address', m2, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - Column('qpr', String(10), index=True), - UniqueConstraint("email_address", name="uq_email_address") - ) + Column('id', Integer, primary_key=True), + Column('email_address', String(100), nullable=False), + Column('qpr', String(10), index=True), + UniqueConstraint("email_address", name="uq_email_address") + ) diffs = self._fixture(m1, m2) @@ -70,29 +70,28 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): else: eq_(diffs, []) - def test_index_becomes_unique(self): m1 = MetaData() m2 = MetaData() Table('order', m1, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - UniqueConstraint('order_id', 'user_id', - name='order_order_id_user_id_unique' - ), - Index('order_user_id_amount_idx', 'user_id', 'amount') - ) + Column('order_id', Integer, primary_key=True), + Column('amount', Numeric(10, 2), nullable=True), + Column('user_id', Integer), + UniqueConstraint('order_id', 'user_id', + name='order_order_id_user_id_unique' + ), + Index('order_user_id_amount_idx', 'user_id', 'amount') + ) Table('order', m2, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - UniqueConstraint('order_id', 'user_id', - name='order_order_id_user_id_unique' - ), - Index('order_user_id_amount_idx', 'user_id', 'amount', unique=True), - ) + Column('order_id', Integer, primary_key=True), + Column('amount', Numeric(10, 2), nullable=True), + Column('user_id', Integer), + UniqueConstraint('order_id', 'user_id', + name='order_order_id_user_id_unique' + ), + Index('order_user_id_amount_idx', 'user_id', 'amount', unique=True), + ) diffs = self._fixture(m1, m2) eq_(diffs[0][0], "remove_index") @@ -103,21 +102,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): eq_(diffs[1][1].name, "order_user_id_amount_idx") eq_(diffs[1][1].unique, True) - - def test_mismatch_db_named_col_flag(self): m1 = MetaData() m2 = MetaData() Table('item', m1, - Column('x', Integer), - UniqueConstraint('x', name="db_generated_name") - ) + Column('x', Integer), + UniqueConstraint('x', name="db_generated_name") + ) # test mismatch between unique=True and # named uq constraint Table('item', m2, - Column('x', Integer, unique=True) - ) + Column('x', Integer, unique=True) + ) diffs = self._fixture(m1, m2) @@ -127,10 +124,10 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): m1 = MetaData() m2 = MetaData() Table('extra', m2, - Column('foo', Integer, index=True), - Column('bar', Integer), - Index('newtable_idx', 'bar') - ) + Column('foo', Integer, index=True), + Column('bar', Integer), + Index('newtable_idx', 'bar') + ) diffs = self._fixture(m1, m2) @@ -142,20 +139,19 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): eq_(diffs[2][0], "add_index") eq_(diffs[2][1].name, "newtable_idx") - def test_named_cols_changed(self): m1 = MetaData() m2 = MetaData() Table('col_change', m1, - Column('x', Integer), - Column('y', Integer), - UniqueConstraint('x', name="nochange") - ) + Column('x', Integer), + Column('y', Integer), + UniqueConstraint('x', name="nochange") + ) Table('col_change', m2, - Column('x', Integer), - Column('y', Integer), - UniqueConstraint('x', 'y', name="nochange") - ) + Column('x', Integer), + Column('y', Integer), + UniqueConstraint('x', 'y', name="nochange") + ) diffs = self._fixture(m1, m2) @@ -173,72 +169,68 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): m2 = MetaData() Table('nothing_changed', m1, - Column('x', String(20), unique=True, index=True) - ) + Column('x', String(20), unique=True, index=True) + ) Table('nothing_changed', m2, - Column('x', String(20), unique=True, index=True) - ) + Column('x', String(20), unique=True, index=True) + ) diffs = self._fixture(m1, m2) eq_(diffs, []) - def test_nothing_changed_two(self): m1 = MetaData() m2 = MetaData() Table('nothing_changed', m1, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20), unique=True), - mysql_engine='InnoDB' - ) + Column('id1', Integer, primary_key=True), + Column('id2', Integer, primary_key=True), + Column('x', String(20), unique=True), + mysql_engine='InnoDB' + ) Table('nothing_changed_related', m1, - Column('id1', Integer), - Column('id2', Integer), - ForeignKeyConstraint(['id1', 'id2'], - ['nothing_changed.id1', 'nothing_changed.id2']), - mysql_engine='InnoDB' - ) + Column('id1', Integer), + Column('id2', Integer), + ForeignKeyConstraint(['id1', 'id2'], + ['nothing_changed.id1', 'nothing_changed.id2']), + mysql_engine='InnoDB' + ) Table('nothing_changed', m2, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20), unique=True), - mysql_engine='InnoDB' - ) + Column('id1', Integer, primary_key=True), + Column('id2', Integer, primary_key=True), + Column('x', String(20), unique=True), + mysql_engine='InnoDB' + ) Table('nothing_changed_related', m2, - Column('id1', Integer), - Column('id2', Integer), - ForeignKeyConstraint(['id1', 'id2'], - ['nothing_changed.id1', 'nothing_changed.id2']), - mysql_engine='InnoDB' - ) - + Column('id1', Integer), + Column('id2', Integer), + ForeignKeyConstraint(['id1', 'id2'], + ['nothing_changed.id1', 'nothing_changed.id2']), + mysql_engine='InnoDB' + ) diffs = self._fixture(m1, m2) eq_(diffs, []) - - def test_nothing_changed_index_named_as_column(self): m1 = MetaData() m2 = MetaData() Table('nothing_changed', m1, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - Index('x', 'x') - ) + Column('id1', Integer, primary_key=True), + Column('id2', Integer, primary_key=True), + Column('x', String(20)), + Index('x', 'x') + ) Table('nothing_changed', m2, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - Index('x', 'x') - ) + Column('id1', Integer, primary_key=True), + Column('id2', Integer, primary_key=True), + Column('x', String(20)), + Index('x', 'x') + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -248,28 +240,28 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): m2 = MetaData() Table("nothing_changed", m1, - Column('id', Integer, primary_key=True), - Column('other_id', - ForeignKey('nc2.id', + Column('id', Integer, primary_key=True), + Column('other_id', + ForeignKey('nc2.id', name='fk_my_table_other_table' ), - nullable=False), - Column('foo', Integer), - mysql_engine='InnoDB') + nullable=False), + Column('foo', Integer), + mysql_engine='InnoDB') Table('nc2', m1, - Column('id', Integer, primary_key=True), - mysql_engine='InnoDB') + Column('id', Integer, primary_key=True), + mysql_engine='InnoDB') Table("nothing_changed", m2, - Column('id', Integer, primary_key=True), - Column('other_id', ForeignKey('nc2.id', - name='fk_my_table_other_table'), - nullable=False), - Column('foo', Integer), - mysql_engine='InnoDB') + Column('id', Integer, primary_key=True), + Column('other_id', ForeignKey('nc2.id', + name='fk_my_table_other_table'), + nullable=False), + Column('foo', Integer), + mysql_engine='InnoDB') Table('nc2', m2, - Column('id', Integer, primary_key=True), - mysql_engine='InnoDB') + Column('id', Integer, primary_key=True), + mysql_engine='InnoDB') diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -278,18 +270,18 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): m2 = MetaData() Table('new_idx', m1, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - ) + Column('id1', Integer, primary_key=True), + Column('id2', Integer, primary_key=True), + Column('x', String(20)), + ) idx = Index('x', 'x') Table('new_idx', m2, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - idx - ) + Column('id1', Integer, primary_key=True), + Column('id2', Integer, primary_key=True), + Column('x', String(20)), + idx + ) diffs = self._fixture(m1, m2) eq_(diffs, [('add_index', idx)]) @@ -300,17 +292,17 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): idx = Index('x', 'x') Table('new_idx', m1, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)), - idx - ) + Column('id1', Integer, primary_key=True), + Column('id2', Integer, primary_key=True), + Column('x', String(20)), + idx + ) Table('new_idx', m2, - Column('id1', Integer, primary_key=True), - Column('id2', Integer, primary_key=True), - Column('x', String(20)) - ) + Column('id1', Integer, primary_key=True), + Column('id2', Integer, primary_key=True), + Column('x', String(20)) + ) diffs = self._fixture(m1, m2) eq_(diffs[0][0], 'remove_index') @@ -319,38 +311,36 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): m1 = MetaData() m2 = MetaData() Table('col_change', m1, - Column('x', Integer), - Column('y', Integer), - UniqueConstraint('x') - ) + Column('x', Integer), + Column('y', Integer), + UniqueConstraint('x') + ) Table('col_change', m2, - Column('x', Integer), - Column('y', Integer), - UniqueConstraint('x', 'y') - ) + Column('x', Integer), + Column('y', Integer), + UniqueConstraint('x', 'y') + ) diffs = self._fixture(m1, m2) diffs = set((cmd, - ('x' in obj.name) if obj.name is not None else False) + ('x' in obj.name) if obj.name is not None else False) for cmd, obj in diffs) if self.reports_unnamed_constraints: assert ("remove_constraint", True) in diffs assert ("add_constraint", False) in diffs - - def test_remove_named_unique_index(self): m1 = MetaData() m2 = MetaData() Table('remove_idx', m1, - Column('x', Integer), - Index('xidx', 'x', unique=True) - ) + Column('x', Integer), + Index('xidx', 'x', unique=True) + ) Table('remove_idx', m2, - Column('x', Integer), - ) + Column('x', Integer), + ) diffs = self._fixture(m1, m2) @@ -360,18 +350,17 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): else: eq_(diffs, []) - def test_remove_named_unique_constraint(self): m1 = MetaData() m2 = MetaData() Table('remove_idx', m1, - Column('x', Integer), - UniqueConstraint('x', name='xidx') - ) + Column('x', Integer), + UniqueConstraint('x', name='xidx') + ) Table('remove_idx', m2, - Column('x', Integer), - ) + Column('x', Integer), + ) diffs = self._fixture(m1, m2) @@ -437,7 +426,6 @@ class AutogenerateUniqueIndexTest(AutogenFixtureTest, TestCase): eq_(diffs, []) - class PGUniqueIndexTest(AutogenerateUniqueIndexTest): reports_unnamed_constraints = True @@ -450,7 +438,7 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): m2 = MetaData() Table('add_ix', m1, Column('x', String(50)), schema="test_schema") Table('add_ix', m2, Column('x', String(50)), - Index('ix_1', 'x'), schema="test_schema") + Index('ix_1', 'x'), schema="test_schema") diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs[0][0], "add_index") @@ -460,9 +448,9 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): m1 = MetaData() m2 = MetaData() Table('add_ix', m1, Column('x', String(50)), Index('ix_1', 'x'), - schema="test_schema") + schema="test_schema") Table('add_ix', m2, Column('x', String(50)), - Index('ix_1', 'x'), schema="test_schema") + Index('ix_1', 'x'), schema="test_schema") diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs, []) @@ -472,7 +460,7 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): m2 = MetaData() Table('add_uq', m1, Column('x', String(50)), schema="test_schema") Table('add_uq', m2, Column('x', String(50)), - UniqueConstraint('x', name='ix_1'), schema="test_schema") + UniqueConstraint('x', name='ix_1'), schema="test_schema") diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs[0][0], "add_constraint") @@ -482,11 +470,11 @@ class PGUniqueIndexTest(AutogenerateUniqueIndexTest): m1 = MetaData() m2 = MetaData() Table('add_uq', m1, Column('x', String(50)), - UniqueConstraint('x', name='ix_1'), - schema="test_schema") + UniqueConstraint('x', name='ix_1'), + schema="test_schema") Table('add_uq', m2, Column('x', String(50)), - UniqueConstraint('x', name='ix_1'), - schema="test_schema") + UniqueConstraint('x', name='ix_1'), + schema="test_schema") diffs = self._fixture(m1, m2, include_schemas=True) eq_(diffs, []) @@ -511,7 +499,7 @@ class MySQLUniqueIndexTest(AutogenerateUniqueIndexTest): def test_removed_idx_index_named_as_column(self): try: super(MySQLUniqueIndexTest, - self).test_removed_idx_index_named_as_column() + self).test_removed_idx_index_named_as_column() except IndexError: assert True else: @@ -521,6 +509,7 @@ class MySQLUniqueIndexTest(AutogenerateUniqueIndexTest): def _get_bind(cls): return db_for_dialect('mysql') + class NoUqReflectionIndexTest(AutogenerateUniqueIndexTest): reports_unique_constraints = False @@ -536,13 +525,13 @@ class NoUqReflectionIndexTest(AutogenerateUniqueIndexTest): def test_unique_not_reported(self): m1 = MetaData() Table('order', m1, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - UniqueConstraint('order_id', 'user_id', - name='order_order_id_user_id_unique' - ) - ) + Column('order_id', Integer, primary_key=True), + Column('amount', Numeric(10, 2), nullable=True), + Column('user_id', Integer), + UniqueConstraint('order_id', 'user_id', + name='order_order_id_user_id_unique' + ) + ) diffs = self._fixture(m1, m1) eq_(diffs, []) @@ -550,19 +539,19 @@ class NoUqReflectionIndexTest(AutogenerateUniqueIndexTest): def test_remove_unique_index_not_reported(self): m1 = MetaData() Table('order', m1, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - Index('oid_ix', 'order_id', 'user_id', - unique=True - ) - ) + Column('order_id', Integer, primary_key=True), + Column('amount', Numeric(10, 2), nullable=True), + Column('user_id', Integer), + Index('oid_ix', 'order_id', 'user_id', + unique=True + ) + ) m2 = MetaData() Table('order', m2, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - ) + Column('order_id', Integer, primary_key=True), + Column('amount', Numeric(10, 2), nullable=True), + Column('user_id', Integer), + ) diffs = self._fixture(m1, m2) eq_(diffs, []) @@ -570,23 +559,24 @@ class NoUqReflectionIndexTest(AutogenerateUniqueIndexTest): def test_remove_plain_index_is_reported(self): m1 = MetaData() Table('order', m1, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - Index('oid_ix', 'order_id', 'user_id') - ) + Column('order_id', Integer, primary_key=True), + Column('amount', Numeric(10, 2), nullable=True), + Column('user_id', Integer), + Index('oid_ix', 'order_id', 'user_id') + ) m2 = MetaData() Table('order', m2, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True), - Column('user_id', Integer), - ) + Column('order_id', Integer, primary_key=True), + Column('amount', Numeric(10, 2), nullable=True), + Column('user_id', Integer), + ) diffs = self._fixture(m1, m2) eq_(diffs[0][0], 'remove_index') class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest): + """this test suite simulates the condition where: a. the dialect doesn't report unique constraints @@ -612,8 +602,8 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest): def get_indexes(self, connection, tablename, **kw): indexes = _get_indexes(self, connection, tablename, **kw) for uq in _get_unique_constraints( - self, connection, tablename, **kw - ): + self, connection, tablename, **kw + ): uq['unique'] = True indexes.append(uq) return indexes @@ -621,4 +611,3 @@ class NoUqReportsIndAsUqTest(NoUqReflectionIndexTest): eng.dialect.get_unique_constraints = unimpl eng.dialect.get_indexes = get_indexes return eng - diff --git a/tests/test_autogen_render.py b/tests/test_autogen_render.py index d253410..901e9f2 100644 --- a/tests/test_autogen_render.py +++ b/tests/test_autogen_render.py @@ -18,7 +18,9 @@ from . import eq_, eq_ignore_whitespace, requires_092, requires_09, requires_094 py3k = sys.version_info >= (3, ) + class AutogenRenderTest(TestCase): + """test individual directives""" @classmethod @@ -38,17 +40,16 @@ class AutogenRenderTest(TestCase): 'dialect': postgresql.dialect() } - def test_render_add_index(self): """ autogenerate.render._add_index """ m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + ) idx = Index('test_active_code_idx', t.c.active, t.c.code) eq_ignore_whitespace( autogenerate.render._add_index(idx, self.autogen_context), @@ -62,11 +63,11 @@ class AutogenRenderTest(TestCase): """ m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + schema='CamelSchema' + ) idx = Index('test_active_code_idx', t.c.active, t.c.code) eq_ignore_whitespace( autogenerate.render._add_index(idx, self.autogen_context), @@ -79,24 +80,24 @@ class AutogenRenderTest(TestCase): m = MetaData() t = Table('t', m, - Column('x', String), - Column('y', String) - ) + Column('x', String), + Column('y', String) + ) idx = Index('foo_idx', t.c.x, t.c.y, - postgresql_where=(t.c.y == 'something')) + postgresql_where=(t.c.y == 'something')) if compat.sqla_08: eq_ignore_whitespace( autogenerate.render._add_index(idx, autogen_context), """op.create_index('foo_idx', 't', ['x', 'y'], unique=False, """ - """postgresql_where=sa.text("t.y = 'something'"))""" + """postgresql_where=sa.text("t.y = 'something'"))""" ) else: eq_ignore_whitespace( autogenerate.render._add_index(idx, autogen_context), """op.create_index('foo_idx', 't', ['x', 'y'], unique=False, """ - """postgresql_where=sa.text('t.y = %(y_1)s'))""" + """postgresql_where=sa.text('t.y = %(y_1)s'))""" ) # def test_render_add_index_func(self): @@ -122,10 +123,10 @@ class AutogenRenderTest(TestCase): """ m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + ) idx = Index('test_active_code_idx', t.c.active, t.c.code) eq_ignore_whitespace( autogenerate.render._drop_index(idx, self.autogen_context), @@ -138,16 +139,16 @@ class AutogenRenderTest(TestCase): """ m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + schema='CamelSchema' + ) idx = Index('test_active_code_idx', t.c.active, t.c.code) eq_ignore_whitespace( autogenerate.render._drop_index(idx, self.autogen_context), "op.drop_index('test_active_code_idx', " + - "table_name='test', schema='CamelSchema')" + "table_name='test', schema='CamelSchema')" ) def test_add_unique_constraint(self): @@ -156,10 +157,10 @@ class AutogenRenderTest(TestCase): """ m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + ) uq = UniqueConstraint(t.c.code, name='uq_test_code') eq_ignore_whitespace( autogenerate.render._add_unique_constraint(uq, self.autogen_context), @@ -172,11 +173,11 @@ class AutogenRenderTest(TestCase): """ m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + schema='CamelSchema' + ) uq = UniqueConstraint(t.c.code, name='uq_test_code') eq_ignore_whitespace( autogenerate.render._add_unique_constraint(uq, self.autogen_context), @@ -189,10 +190,10 @@ class AutogenRenderTest(TestCase): """ m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + ) uq = UniqueConstraint(t.c.code, name='uq_test_code') eq_ignore_whitespace( autogenerate.render._drop_constraint(uq, self.autogen_context), @@ -205,11 +206,11 @@ class AutogenRenderTest(TestCase): """ m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + schema='CamelSchema' + ) uq = UniqueConstraint(t.c.code, name='uq_test_code') eq_ignore_whitespace( autogenerate.render._drop_constraint(uq, self.autogen_context), @@ -219,14 +220,14 @@ class AutogenRenderTest(TestCase): def test_render_table_upgrade(self): m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('name', Unicode(255)), - Column("address_id", Integer, ForeignKey("address.id")), - Column("timestamp", DATETIME, server_default="NOW()"), - Column("amount", Numeric(5, 2)), - UniqueConstraint("name", name="uq_name"), - UniqueConstraint("timestamp"), - ) + Column('id', Integer, primary_key=True), + Column('name', Unicode(255)), + Column("address_id", Integer, ForeignKey("address.id")), + Column("timestamp", DATETIME, server_default="NOW()"), + Column("amount", Numeric(5, 2)), + UniqueConstraint("name", name="uq_name"), + UniqueConstraint("timestamp"), + ) eq_ignore_whitespace( autogenerate.render._add_table(t, self.autogen_context), "op.create_table('test'," @@ -234,8 +235,8 @@ class AutogenRenderTest(TestCase): "sa.Column('name', sa.Unicode(length=255), nullable=True)," "sa.Column('address_id', sa.Integer(), nullable=True)," "sa.Column('timestamp', sa.DATETIME(), " - "server_default='NOW()', " - "nullable=True)," + "server_default='NOW()', " + "nullable=True)," "sa.Column('amount', sa.Numeric(precision=5, scale=2), nullable=True)," "sa.ForeignKeyConstraint(['address_id'], ['address.id'], )," "sa.PrimaryKeyConstraint('id')," @@ -247,10 +248,10 @@ class AutogenRenderTest(TestCase): def test_render_table_w_schema(self): m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('address.id')), - schema='foo' - ) + Column('id', Integer, primary_key=True), + Column('q', Integer, ForeignKey('address.id')), + schema='foo' + ) eq_ignore_whitespace( autogenerate.render._add_table(t, self.autogen_context), "op.create_table('test'," @@ -299,9 +300,9 @@ class AutogenRenderTest(TestCase): def test_render_table_w_fk_schema(self): m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('foo.address.id')), - ) + Column('id', Integer, primary_key=True), + Column('q', Integer, ForeignKey('foo.address.id')), + ) eq_ignore_whitespace( autogenerate.render._add_table(t, self.autogen_context), "op.create_table('test'," @@ -315,9 +316,9 @@ class AutogenRenderTest(TestCase): def test_render_table_w_metadata_schema(self): m = MetaData(schema="foo") t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('address.id')), - ) + Column('id', Integer, primary_key=True), + Column('q', Integer, ForeignKey('address.id')), + ) eq_ignore_whitespace( re.sub(r"u'", "'", autogenerate.render._add_table(t, self.autogen_context)), "op.create_table('test'," @@ -332,9 +333,9 @@ class AutogenRenderTest(TestCase): def test_render_table_w_metadata_schema_override(self): m = MetaData(schema="foo") t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('bar.address.id')), - ) + Column('id', Integer, primary_key=True), + Column('q', Integer, ForeignKey('bar.address.id')), + ) eq_ignore_whitespace( autogenerate.render._add_table(t, self.autogen_context), "op.create_table('test'," @@ -349,10 +350,10 @@ class AutogenRenderTest(TestCase): def test_render_addtl_args(self): m = MetaData() t = Table('test', m, - Column('id', Integer, primary_key=True), - Column('q', Integer, ForeignKey('bar.address.id')), - sqlite_autoincrement=True, mysql_engine="InnoDB" - ) + Column('id', Integer, primary_key=True), + Column('q', Integer, ForeignKey('bar.address.id')), + sqlite_autoincrement=True, mysql_engine="InnoDB" + ) eq_ignore_whitespace( autogenerate.render._add_table(t, self.autogen_context), "op.create_table('test'," @@ -366,7 +367,7 @@ class AutogenRenderTest(TestCase): def test_render_drop_table(self): eq_( autogenerate.render._drop_table(Table("sometable", MetaData()), - self.autogen_context), + self.autogen_context), "op.drop_table('sometable')" ) @@ -407,26 +408,26 @@ class AutogenRenderTest(TestCase): def test_render_add_column(self): eq_( autogenerate.render._add_column( - None, "foo", Column("x", Integer, server_default="5"), - self.autogen_context), + None, "foo", Column("x", Integer, server_default="5"), + self.autogen_context), "op.add_column('foo', sa.Column('x', sa.Integer(), " - "server_default='5', nullable=True))" + "server_default='5', nullable=True))" ) def test_render_add_column_w_schema(self): eq_( autogenerate.render._add_column( - "foo", "bar", Column("x", Integer, server_default="5"), - self.autogen_context), + "foo", "bar", Column("x", Integer, server_default="5"), + self.autogen_context), "op.add_column('bar', sa.Column('x', sa.Integer(), " - "server_default='5', nullable=True), schema='foo')" + "server_default='5', nullable=True), schema='foo')" ) def test_render_drop_column(self): eq_( autogenerate.render._drop_column( - None, "foo", Column("x", Integer, server_default="5"), - self.autogen_context), + None, "foo", Column("x", Integer, server_default="5"), + self.autogen_context), "op.drop_column('foo', 'x')" ) @@ -434,8 +435,8 @@ class AutogenRenderTest(TestCase): def test_render_drop_column_w_schema(self): eq_( autogenerate.render._drop_column( - "foo", "bar", Column("x", Integer, server_default="5"), - self.autogen_context), + "foo", "bar", Column("x", Integer, server_default="5"), + self.autogen_context), "op.drop_column('bar', 'x', schema='foo')" ) @@ -444,35 +445,35 @@ class AutogenRenderTest(TestCase): eq_( autogenerate.render._render_server_default( "nextval('group_to_perm_group_to_perm_id_seq'::regclass)", - self.autogen_context), + self.autogen_context), '"nextval(\'group_to_perm_group_to_perm_id_seq\'::regclass)"' ) def test_render_col_with_server_default(self): c = Column('updated_at', TIMESTAMP(), - server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)', - nullable=False) + server_default='TIMEZONE("utc", CURRENT_TIMESTAMP)', + nullable=False) result = autogenerate.render._render_column( - c, self.autogen_context - ) + c, self.autogen_context + ) eq_( result, 'sa.Column(\'updated_at\', sa.TIMESTAMP(), ' - 'server_default=\'TIMEZONE("utc", CURRENT_TIMESTAMP)\', ' - 'nullable=False)' + 'server_default=\'TIMEZONE("utc", CURRENT_TIMESTAMP)\', ' + 'nullable=False)' ) def test_render_col_autoinc_false_mysql(self): c = Column('some_key', Integer, primary_key=True, autoincrement=False) Table('some_table', MetaData(), c) result = autogenerate.render._render_column( - c, self.autogen_context - ) + c, self.autogen_context + ) eq_( result, 'sa.Column(\'some_key\', sa.Integer(), ' - 'autoincrement=False, ' - 'nullable=False)' + 'autoincrement=False, ' + 'nullable=False)' ) def test_render_custom(self): @@ -493,14 +494,14 @@ class AutogenRenderTest(TestCase): }} t = Table('t', MetaData(), - Column('x', Integer), - Column('y', Integer), - PrimaryKeyConstraint('x'), - ForeignKeyConstraint(['x'], ['y']) - ) + Column('x', Integer), + Column('y', Integer), + PrimaryKeyConstraint('x'), + ForeignKeyConstraint(['x'], ['y']) + ) result = autogenerate.render._add_table( - t, autogen_context - ) + t, autogen_context + ) eq_( result, """sa.create_table('t', col(x), @@ -510,32 +511,32 @@ render:primary_key\n)""" def test_render_modify_type(self): eq_ignore_whitespace( autogenerate.render._modify_col( - "sometable", "somecolumn", - self.autogen_context, - type_=CHAR(10), existing_type=CHAR(20)), + "sometable", "somecolumn", + self.autogen_context, + type_=CHAR(10), existing_type=CHAR(20)), "op.alter_column('sometable', 'somecolumn', " - "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))" + "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10))" ) def test_render_modify_type_w_schema(self): eq_ignore_whitespace( autogenerate.render._modify_col( - "sometable", "somecolumn", - self.autogen_context, - type_=CHAR(10), existing_type=CHAR(20), - schema='foo'), + "sometable", "somecolumn", + self.autogen_context, + type_=CHAR(10), existing_type=CHAR(20), + schema='foo'), "op.alter_column('sometable', 'somecolumn', " - "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10), " - "schema='foo')" + "existing_type=sa.CHAR(length=20), type_=sa.CHAR(length=10), " + "schema='foo')" ) def test_render_modify_nullable(self): eq_ignore_whitespace( autogenerate.render._modify_col( - "sometable", "somecolumn", - self.autogen_context, - existing_type=Integer(), - nullable=True), + "sometable", "somecolumn", + self.autogen_context, + existing_type=Integer(), + nullable=True), "op.alter_column('sometable', 'somecolumn', " "existing_type=sa.Integer(), nullable=True)" ) @@ -543,10 +544,10 @@ render:primary_key\n)""" def test_render_modify_nullable_w_schema(self): eq_ignore_whitespace( autogenerate.render._modify_col( - "sometable", "somecolumn", - self.autogen_context, - existing_type=Integer(), - nullable=True, schema='foo'), + "sometable", "somecolumn", + self.autogen_context, + existing_type=Integer(), + nullable=True, schema='foo'), "op.alter_column('sometable', 'somecolumn', " "existing_type=sa.Integer(), nullable=True, schema='foo')" ) @@ -594,13 +595,13 @@ render:primary_key\n)""" m = MetaData() Table('t', m, Column('c', Integer)) t2 = Table('t2', m, Column('c_rem', Integer, - ForeignKey('t.c', name="fk1", use_alter=True))) + ForeignKey('t.c', name="fk1", use_alter=True))) const = list(t2.foreign_keys)[0].constraint eq_ignore_whitespace( autogenerate.render._render_constraint(const, self.autogen_context), "sa.ForeignKeyConstraint(['c_rem'], ['t.c'], " - "name='fk1', use_alter=True)" + "name='fk1', use_alter=True)" ) def test_render_fk_constraint_w_metadata_schema(self): @@ -626,7 +627,6 @@ render:primary_key\n)""" "sa.CheckConstraint('im a constraint', name='cc1')" ) - def test_render_check_constraint_sqlexpr(self): c = column('c') five = literal_column('5') @@ -653,29 +653,27 @@ render:primary_key\n)""" def test_render_modify_nullable_w_default(self): eq_ignore_whitespace( autogenerate.render._modify_col( - "sometable", "somecolumn", - self.autogen_context, - existing_type=Integer(), - existing_server_default="5", - nullable=True), + "sometable", "somecolumn", + self.autogen_context, + existing_type=Integer(), + existing_server_default="5", + nullable=True), "op.alter_column('sometable', 'somecolumn', " "existing_type=sa.Integer(), nullable=True, " "existing_server_default='5')" ) - - def test_render_enum(self): eq_ignore_whitespace( autogenerate.render._repr_type( - Enum("one", "two", "three", name="myenum"), - self.autogen_context), + Enum("one", "two", "three", name="myenum"), + self.autogen_context), "sa.Enum('one', 'two', 'three', name='myenum')" ) eq_ignore_whitespace( autogenerate.render._repr_type( - Enum("one", "two", "three"), - self.autogen_context), + Enum("one", "two", "three"), + self.autogen_context), "sa.Enum('one', 'two', 'three')" ) @@ -696,7 +694,9 @@ render:primary_key\n)""" def test_repr_user_type_user_prefix_None(self): from sqlalchemy.types import UserDefinedType + class MyType(UserDefinedType): + def get_col_spec(self): return "MYTYPE" @@ -717,7 +717,9 @@ render:primary_key\n)""" def test_repr_user_type_user_prefix_present(self): from sqlalchemy.types import UserDefinedType + class MyType(UserDefinedType): + def get_col_spec(self): return "MYTYPE" @@ -755,9 +757,10 @@ render:primary_key\n)""" "mysql.VARCHAR(charset='utf8', national=True, length=20)" ) eq_(autogen_context['imports'], - set(['from sqlalchemy.dialects import mysql']) + set(['from sqlalchemy.dialects import mysql']) ) + class RenderNamingConventionTest(TestCase): @classmethod @@ -771,30 +774,29 @@ class RenderNamingConventionTest(TestCase): 'dialect': postgresql.dialect() } - def setUp(self): convention = { - "ix": 'ix_%(custom)s_%(column_0_label)s', - "uq": "uq_%(custom)s_%(table_name)s_%(column_0_name)s", - "ck": "ck_%(custom)s_%(table_name)s", - "fk": "fk_%(custom)s_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", - "pk": "pk_%(custom)s_%(table_name)s", - "custom": lambda const, table: "ct" + "ix": 'ix_%(custom)s_%(column_0_label)s', + "uq": "uq_%(custom)s_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(custom)s_%(table_name)s", + "fk": "fk_%(custom)s_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(custom)s_%(table_name)s", + "custom": lambda const, table: "ct" } self.metadata = MetaData( - naming_convention=convention - ) + naming_convention=convention + ) def test_schema_type_boolean(self): t = Table('t', self.metadata, Column('c', Boolean(name='xyz'))) eq_ignore_whitespace( autogenerate.render._add_column( - None, "t", t.c.c, - self.autogen_context), + None, "t", t.c.c, + self.autogen_context), "op.add_column('t', " - "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))" + "sa.Column('c', sa.Boolean(name='xyz'), nullable=True))" ) def test_explicit_unique_constraint(self): @@ -819,10 +821,10 @@ class RenderNamingConventionTest(TestCase): def test_render_add_index(self): t = Table('test', self.metadata, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + ) idx = Index(None, t.c.active, t.c.code) eq_ignore_whitespace( autogenerate.render._add_index(idx, self.autogen_context), @@ -832,10 +834,10 @@ class RenderNamingConventionTest(TestCase): def test_render_drop_index(self): t = Table('test', self.metadata, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + ) idx = Index(None, t.c.active, t.c.code) eq_ignore_whitespace( autogenerate.render._drop_index(idx, self.autogen_context), @@ -844,11 +846,11 @@ class RenderNamingConventionTest(TestCase): def test_render_add_index_schema(self): t = Table('test', self.metadata, - Column('id', Integer, primary_key=True), - Column('active', Boolean()), - Column('code', String(255)), - schema='CamelSchema' - ) + Column('id', Integer, primary_key=True), + Column('active', Boolean()), + Column('code', String(255)), + schema='CamelSchema' + ) idx = Index(None, t.c.active, t.c.code) eq_ignore_whitespace( autogenerate.render._add_index(idx, self.autogen_context), @@ -856,14 +858,13 @@ class RenderNamingConventionTest(TestCase): "['active', 'code'], unique=False, schema='CamelSchema')" ) - def test_implicit_unique_constraint(self): t = Table('t', self.metadata, Column('c', Integer, unique=True)) uq = [c for c in t.constraints if isinstance(c, UniqueConstraint)][0] eq_ignore_whitespace( autogenerate.render._render_unique_constraint(uq, - self.autogen_context - ), + self.autogen_context + ), "sa.UniqueConstraint('c', name=op.f('uq_ct_t_c'))" ) @@ -872,7 +873,7 @@ class RenderNamingConventionTest(TestCase): eq_ignore_whitespace( autogenerate.render._add_table(t, self.autogen_context), "op.create_table('t',sa.Column('c', sa.Integer(), nullable=False)," - "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))" + "sa.PrimaryKeyConstraint('c', name=op.f('pk_ct_t')))" ) def test_inline_ck_constraint(self): @@ -880,7 +881,7 @@ class RenderNamingConventionTest(TestCase): eq_ignore_whitespace( autogenerate.render._add_table(t, self.autogen_context), "op.create_table('t',sa.Column('c', sa.Integer(), nullable=True)," - "sa.CheckConstraint('c > 5', name=op.f('ck_ct_t')))" + "sa.CheckConstraint('c > 5', name=op.f('ck_ct_t')))" ) def test_inline_fk(self): @@ -888,7 +889,7 @@ class RenderNamingConventionTest(TestCase): eq_ignore_whitespace( autogenerate.render._add_table(t, self.autogen_context), "op.create_table('t',sa.Column('c', sa.Integer(), nullable=True)," - "sa.ForeignKeyConstraint(['c'], ['q.id'], name=op.f('fk_ct_t_c_q')))" + "sa.ForeignKeyConstraint(['c'], ['q.id'], name=op.f('fk_ct_t_c_q')))" ) def test_render_check_constraint_renamed(self): diff --git a/tests/test_autogenerate.py b/tests/test_autogenerate.py index cc9e118..f52aebb 100644 --- a/tests/test_autogenerate.py +++ b/tests/test_autogenerate.py @@ -14,11 +14,13 @@ from sqlalchemy.engine.reflection import Inspector from alembic import autogenerate from alembic.migration import MigrationContext from . import staging_env, sqlite_db, clear_staging_env, eq_, \ - db_for_dialect + db_for_dialect py3k = sys.version_info >= (3, ) names_in_this_test = set() + + def _default_include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name in names_in_this_test @@ -29,11 +31,15 @@ _default_object_filters = [ _default_include_object ] from sqlalchemy import event + + @event.listens_for(Table, "after_parent_attach") def new_table(table, parent): names_in_this_test.add(table.name) + class AutogenTest(object): + @classmethod def _get_bind(cls): return sqlite_db() @@ -66,14 +72,16 @@ class AutogenTest(object): 'connection': connection, 'dialect': connection.dialect, 'context': context - } + } @classmethod def teardown_class(cls): cls.m1.drop_all(cls.bind) clear_staging_env() + class AutogenFixtureTest(object): + def _fixture(self, m1, m2, include_schemas=False): self.metadata, model_metadata = m1, m2 self.metadata.create_all(self.bind) @@ -98,13 +106,13 @@ class AutogenFixtureTest(object): 'connection': connection, 'dialect': connection.dialect, 'context': context - } + } diffs = [] autogenerate._produce_net_changes(connection, model_metadata, diffs, autogen_context, object_filters=_default_object_filters, include_schemas=include_schemas - ) + ) return diffs reports_unnamed_constraints = False @@ -124,6 +132,7 @@ class AutogenFixtureTest(object): class AutogenCrossSchemaTest(AutogenTest, TestCase): + @classmethod def _get_bind(cls): cls.test_schema_name = "test_schema" @@ -133,19 +142,19 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): def _get_db_schema(cls): m = MetaData() Table('t1', m, - Column('x', Integer) - ) + Column('x', Integer) + ) Table('t2', m, - Column('y', Integer), - schema=cls.test_schema_name - ) + Column('y', Integer), + schema=cls.test_schema_name + ) Table('t6', m, - Column('u', Integer) - ) + Column('u', Integer) + ) Table('t7', m, - Column('v', Integer), - schema=cls.test_schema_name - ) + Column('v', Integer), + schema=cls.test_schema_name + ) return m @@ -153,25 +162,26 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): def _get_model_schema(cls): m = MetaData() Table('t3', m, - Column('q', Integer) - ) + Column('q', Integer) + ) Table('t4', m, - Column('z', Integer), - schema=cls.test_schema_name - ) + Column('z', Integer), + schema=cls.test_schema_name + ) Table('t6', m, - Column('u', Integer) - ) + Column('u', Integer) + ) Table('t7', m, - Column('v', Integer), - schema=cls.test_schema_name - ) + Column('v', Integer), + schema=cls.test_schema_name + ) return m def test_default_schema_omitted_upgrade(self): metadata = self.m2 connection = self.context.bind diffs = [] + def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t3" @@ -189,6 +199,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): metadata = self.m2 connection = self.context.bind diffs = [] + def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t4" @@ -206,6 +217,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): metadata = self.m2 connection = self.context.bind diffs = [] + def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t1" @@ -223,6 +235,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): metadata = self.m2 connection = self.context.bind diffs = [] + def include_object(obj, name, type_, reflected, compare_to): if type_ == "table": return name == "t2" @@ -238,6 +251,7 @@ class AutogenCrossSchemaTest(AutogenTest, TestCase): class AutogenDefaultSchemaTest(AutogenFixtureTest, TestCase): + @classmethod def _get_bind(cls): cls.test_schema_name = "test_schema" @@ -285,7 +299,6 @@ class AutogenDefaultSchemaTest(AutogenFixtureTest, TestCase): Table('a', m2, Column('x', String(50)), schema=default_schema) Table('a', m2, Column('y', String(50)), schema="test_schema") - diffs = self._fixture(m1, m2, include_schemas=True) eq_(len(diffs), 1) eq_(diffs[0][0], "add_table") @@ -303,28 +316,28 @@ class ModelOne(object): m = MetaData(schema=schema) Table('user', m, - Column('id', Integer, primary_key=True), - Column('name', String(50)), - Column('a1', Text), - Column("pw", String(50)) - ) + Column('id', Integer, primary_key=True), + Column('name', String(50)), + Column('a1', Text), + Column("pw", String(50)) + ) Table('address', m, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - ) + Column('id', Integer, primary_key=True), + Column('email_address', String(100), nullable=False), + ) Table('order', m, - Column('order_id', Integer, primary_key=True), - Column("amount", Numeric(8, 2), nullable=False, - server_default="0"), - CheckConstraint('amount >= 0', name='ck_order_amount') - ) + Column('order_id', Integer, primary_key=True), + Column("amount", Numeric(8, 2), nullable=False, + server_default="0"), + CheckConstraint('amount >= 0', name='ck_order_amount') + ) Table('extra', m, - Column("x", CHAR), - Column('uid', Integer, ForeignKey('user.id')) - ) + Column("x", CHAR), + Column('uid', Integer, ForeignKey('user.id')) + ) return m @@ -335,35 +348,34 @@ class ModelOne(object): m = MetaData(schema=schema) Table('user', m, - Column('id', Integer, primary_key=True), - Column('name', String(50), nullable=False), - Column('a1', Text, server_default="x") - ) + Column('id', Integer, primary_key=True), + Column('name', String(50), nullable=False), + Column('a1', Text, server_default="x") + ) Table('address', m, - Column('id', Integer, primary_key=True), - Column('email_address', String(100), nullable=False), - Column('street', String(50)), - ) + Column('id', Integer, primary_key=True), + Column('email_address', String(100), nullable=False), + Column('street', String(50)), + ) Table('order', m, - Column('order_id', Integer, primary_key=True), - Column('amount', Numeric(10, 2), nullable=True, - server_default="0"), - Column('user_id', Integer, ForeignKey('user.id')), - CheckConstraint('amount > -1', name='ck_order_amount'), - ) + Column('order_id', Integer, primary_key=True), + Column('amount', Numeric(10, 2), nullable=True, + server_default="0"), + Column('user_id', Integer, ForeignKey('user.id')), + CheckConstraint('amount > -1', name='ck_order_amount'), + ) Table('item', m, - Column('id', Integer, primary_key=True), - Column('description', String(100)), - Column('order_id', Integer, ForeignKey('order.order_id')), - CheckConstraint('len(description) > 5') - ) + Column('id', Integer, primary_key=True), + Column('description', String(100)), + Column('order_id', Integer, ForeignKey('order.order_id')), + CheckConstraint('len(description) > 5') + ) return m - class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): def test_diffs(self): @@ -375,7 +387,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): autogenerate._produce_net_changes(connection, metadata, diffs, self.autogen_context, object_filters=_default_object_filters, - ) + ) eq_( diffs[0], @@ -415,7 +427,6 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): eq_(diffs[7][0][5], True) eq_(diffs[7][0][6], False) - def test_render_nothing(self): context = MigrationContext.configure( connection=self.bind.connect(), @@ -431,11 +442,11 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): autogenerate._produce_migration_diffs(context, template_args, set()) eq_(re.sub(r"u'", "'", template_args['upgrades']), -"""### commands auto generated by Alembic - please adjust! ### + """### commands auto generated by Alembic - please adjust! ### pass ### end Alembic commands ###""") eq_(re.sub(r"u'", "'", template_args['downgrades']), -"""### commands auto generated by Alembic - please adjust! ### + """### commands auto generated by Alembic - please adjust! ### pass ### end Alembic commands ###""") @@ -446,7 +457,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): autogenerate._produce_migration_diffs(self.context, template_args, set()) eq_(re.sub(r"u'", "'", template_args['upgrades']), -"""### commands auto generated by Alembic - please adjust! ### + """### commands auto generated by Alembic - please adjust! ### op.create_table('item', sa.Column('id', sa.Integer(), nullable=False), sa.Column('description', sa.String(length=100), nullable=True), @@ -474,7 +485,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): ### end Alembic commands ###""") eq_(re.sub(r"u'", "'", template_args['downgrades']), -"""### commands auto generated by Alembic - please adjust! ### + """### commands auto generated by Alembic - please adjust! ### op.alter_column('user', 'name', existing_type=sa.VARCHAR(length=50), nullable=True) @@ -506,7 +517,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): 'compare_server_default': True, 'target_metadata': self.m2, 'include_symbol': lambda name, schema=None: - name in ('address', 'order'), + name in ('address', 'order'), 'upgrade_token': "upgrades", 'downgrade_token': "downgrades", 'alembic_module_prefix': 'op.', @@ -517,7 +528,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): autogenerate._produce_migration_diffs(context, template_args, set()) template_args['upgrades'] = template_args['upgrades'].replace("u'", "'") template_args['downgrades'] = template_args['downgrades'].\ - replace("u'", "'") + replace("u'", "'") assert "alter_column('user'" not in template_args['upgrades'] assert "alter_column('user'" not in template_args['downgrades'] assert "alter_column('order'" in template_args['upgrades'] @@ -559,7 +570,7 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): template_args['upgrades'] = template_args['upgrades'].replace("u'", "'") template_args['downgrades'] = template_args['downgrades'].\ - replace("u'", "'") + replace("u'", "'") assert "op.create_table('item'" not in template_args['upgrades'] assert "op.create_table('item'" not in template_args['downgrades'] @@ -573,19 +584,19 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): def test_skip_null_type_comparison_reflected(self): diff = [] autogenerate.compare._compare_type(None, "sometable", "somecol", - Column("somecol", NULLTYPE), - Column("somecol", Integer()), - diff, self.autogen_context - ) + Column("somecol", NULLTYPE), + Column("somecol", Integer()), + diff, self.autogen_context + ) assert not diff def test_skip_null_type_comparison_local(self): diff = [] autogenerate.compare._compare_type(None, "sometable", "somecol", - Column("somecol", Integer()), - Column("somecol", NULLTYPE), - diff, self.autogen_context - ) + Column("somecol", Integer()), + Column("somecol", NULLTYPE), + diff, self.autogen_context + ) assert not diff def test_affinity_typedec(self): @@ -600,10 +611,10 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): diff = [] autogenerate.compare._compare_type(None, "sometable", "somecol", - Column("somecol", Integer, nullable=True), - Column("somecol", MyType()), - diff, self.autogen_context - ) + Column("somecol", Integer, nullable=True), + Column("somecol", MyType()), + diff, self.autogen_context + ) assert not diff def test_dont_barf_on_already_reflected(self): @@ -613,17 +624,17 @@ class AutogenerateDiffTest(ModelOne, AutogenTest, TestCase): autogenerate.compare._compare_tables( OrderedSet([(None, 'extra'), (None, 'user')]), OrderedSet(), [], inspector, - MetaData(), diffs, self.autogen_context + MetaData(), diffs, self.autogen_context ) eq_( [(rec[0], rec[1].name) for rec in diffs], [('remove_table', 'extra'), ('remove_table', 'user')] ) + class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase): schema = "test_schema" - @classmethod def _get_bind(cls): return db_for_dialect('postgresql') @@ -693,14 +704,14 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase): ) template_args = {} autogenerate._produce_migration_diffs(context, template_args, set(), - include_symbol=lambda name, schema: False - ) + include_symbol=lambda name, schema: False + ) eq_(re.sub(r"u'", "'", template_args['upgrades']), -"""### commands auto generated by Alembic - please adjust! ### + """### commands auto generated by Alembic - please adjust! ### pass ### end Alembic commands ###""") eq_(re.sub(r"u'", "'", template_args['downgrades']), -"""### commands auto generated by Alembic - please adjust! ### + """### commands auto generated by Alembic - please adjust! ### pass ### end Alembic commands ###""") @@ -709,13 +720,13 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase): template_args = {} autogenerate._produce_migration_diffs( - self.context, template_args, set(), - include_object=_default_include_object, - include_schemas=True - ) + self.context, template_args, set(), + include_object=_default_include_object, + include_schemas=True + ) eq_(re.sub(r"u'", "'", template_args['upgrades']), -"""### commands auto generated by Alembic - please adjust! ### + """### commands auto generated by Alembic - please adjust! ### op.create_table('item', sa.Column('id', sa.Integer(), nullable=False), sa.Column('description', sa.String(length=100), nullable=True), @@ -747,7 +758,7 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase): ### end Alembic commands ###""" % {"schema": self.schema}) eq_(re.sub(r"u'", "'", template_args['downgrades']), -"""### commands auto generated by Alembic - please adjust! ### + """### commands auto generated by Alembic - please adjust! ### op.alter_column('user', 'name', existing_type=sa.VARCHAR(length=50), nullable=True, @@ -776,10 +787,8 @@ class AutogenerateDiffTestWSchema(ModelOne, AutogenTest, TestCase): ### end Alembic commands ###""" % {"schema": self.schema}) - - - class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase): + @classmethod def _get_db_schema(cls): m = MetaData() @@ -804,7 +813,7 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase): diffs = [] autogenerate._produce_net_changes(self.context.bind, self.m2, - diffs, self.autogen_context) + diffs, self.autogen_context) first_table = self.m2.tables['sometable'] first_column = first_table.columns['id'] @@ -827,7 +836,7 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase): diffs = [] autogenerate._produce_net_changes(self.context.bind, self.m2, - diffs, self.autogen_context) + diffs, self.autogen_context) eq_(diffs, []) @@ -838,21 +847,22 @@ class AutogenerateCustomCompareTypeTest(AutogenTest, TestCase): diffs = [] autogenerate._produce_net_changes(self.context.bind, self.m2, - diffs, self.autogen_context) + diffs, self.autogen_context) eq_(diffs[0][0][0], 'modify_type') eq_(diffs[1][0][0], 'modify_type') class AutogenKeyTest(AutogenTest, TestCase): + @classmethod def _get_db_schema(cls): m = MetaData() Table('someothertable', m, - Column('id', Integer, primary_key=True), - Column('value', Integer, key="somekey"), - ) + Column('id', Integer, primary_key=True), + Column('value', Integer, key="somekey"), + ) return m @classmethod @@ -860,17 +870,18 @@ class AutogenKeyTest(AutogenTest, TestCase): m = MetaData() Table('sometable', m, - Column('id', Integer, primary_key=True), - Column('value', Integer, key="someotherkey"), - ) + Column('id', Integer, primary_key=True), + Column('value', Integer, key="someotherkey"), + ) Table('someothertable', m, - Column('id', Integer, primary_key=True), - Column('value', Integer, key="somekey"), - Column("othervalue", Integer, key="otherkey") - ) + Column('id', Integer, primary_key=True), + Column('value', Integer, key="somekey"), + Column("othervalue", Integer, key="otherkey") + ) return m symbols = ['someothertable', 'sometable'] + def test_autogen(self): metadata = self.m2 connection = self.context.bind @@ -886,7 +897,9 @@ class AutogenKeyTest(AutogenTest, TestCase): eq_(diffs[1][0], "add_column") eq_(diffs[1][3].key, "otherkey") + class AutogenerateDiffOrderTest(AutogenTest, TestCase): + @classmethod def _get_db_schema(cls): return MetaData() @@ -895,12 +908,12 @@ class AutogenerateDiffOrderTest(AutogenTest, TestCase): def _get_model_schema(cls): m = MetaData() Table('parent', m, - Column('id', Integer, primary_key=True) - ) + Column('id', Integer, primary_key=True) + ) Table('child', m, - Column('parent_id', Integer, ForeignKey('parent.id')), - ) + Column('parent_id', Integer, ForeignKey('parent.id')), + ) return m @@ -925,6 +938,7 @@ class AutogenerateDiffOrderTest(AutogenTest, TestCase): class CompareMetadataTest(ModelOne, AutogenTest, TestCase): + def test_compare_metadata(self): metadata = self.m2 @@ -1035,6 +1049,7 @@ class CompareMetadataTest(ModelOne, AutogenTest, TestCase): eq_(diffs[2][1][5], False) eq_(diffs[2][1][6], True) + class PGCompareMetaData(ModelOne, AutogenTest, TestCase): schema = "test_schema" diff --git a/tests/test_bulk_insert.py b/tests/test_bulk_insert.py index cc56731..13029c7 100644 --- a/tests/test_bulk_insert.py +++ b/tests/test_bulk_insert.py @@ -8,107 +8,121 @@ from sqlalchemy.types import TypeEngine from . import op_fixture, eq_, assert_raises_message + def _table_fixture(dialect, as_sql): context = op_fixture(dialect, as_sql) t1 = table("ins_table", - column('id', Integer), - column('v1', String()), - column('v2', String()), - ) + column('id', Integer), + column('v1', String()), + column('v2', String()), + ) return context, t1 + def _big_t_table_fixture(dialect, as_sql): context = op_fixture(dialect, as_sql) t1 = Table("ins_table", MetaData(), - Column('id', Integer, primary_key=True), - Column('v1', String()), - Column('v2', String()), - ) + Column('id', Integer, primary_key=True), + Column('v1', String()), + Column('v2', String()), + ) return context, t1 + def _test_bulk_insert(dialect, as_sql): context, t1 = _table_fixture(dialect, as_sql) op.bulk_insert(t1, [ - {'id':1, 'v1':'row v1', 'v2':'row v5'}, - {'id':2, 'v1':'row v2', 'v2':'row v6'}, - {'id':3, 'v1':'row v3', 'v2':'row v7'}, - {'id':4, 'v1':'row v4', 'v2':'row v8'}, + {'id': 1, 'v1': 'row v1', 'v2': 'row v5'}, + {'id': 2, 'v1': 'row v2', 'v2': 'row v6'}, + {'id': 3, 'v1': 'row v3', 'v2': 'row v7'}, + {'id': 4, 'v1': 'row v4', 'v2': 'row v8'}, ]) return context + def _test_bulk_insert_single(dialect, as_sql): context, t1 = _table_fixture(dialect, as_sql) op.bulk_insert(t1, [ - {'id':1, 'v1':'row v1', 'v2':'row v5'}, + {'id': 1, 'v1': 'row v1', 'v2': 'row v5'}, ]) return context + def _test_bulk_insert_single_bigt(dialect, as_sql): context, t1 = _big_t_table_fixture(dialect, as_sql) op.bulk_insert(t1, [ - {'id':1, 'v1':'row v1', 'v2':'row v5'}, + {'id': 1, 'v1': 'row v1', 'v2': 'row v5'}, ]) return context + def test_bulk_insert(): context = _test_bulk_insert('default', False) context.assert_( 'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)' ) + def test_bulk_insert_wrong_cols(): context = op_fixture('postgresql') t1 = table("ins_table", - column('id', Integer), - column('v1', String()), - column('v2', String()), - ) + column('id', Integer), + column('v1', String()), + column('v2', String()), + ) op.bulk_insert(t1, [ - {'v1':'row v1', }, + {'v1': 'row v1', }, ]) context.assert_( 'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)' ) + def test_bulk_insert_no_rows(): context, t1 = _table_fixture('default', False) op.bulk_insert(t1, []) context.assert_() + def test_bulk_insert_pg(): context = _test_bulk_insert('postgresql', False) context.assert_( 'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)' ) + def test_bulk_insert_pg_single(): context = _test_bulk_insert_single('postgresql', False) context.assert_( 'INSERT INTO ins_table (id, v1, v2) VALUES (%(id)s, %(v1)s, %(v2)s)' ) + def test_bulk_insert_pg_single_as_sql(): context = _test_bulk_insert_single('postgresql', True) context.assert_( "INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')" ) + def test_bulk_insert_pg_single_big_t_as_sql(): context = _test_bulk_insert_single_bigt('postgresql', True) context.assert_( "INSERT INTO ins_table (id, v1, v2) VALUES (1, 'row v1', 'row v5')" ) + def test_bulk_insert_mssql(): context = _test_bulk_insert('mssql', False) context.assert_( 'INSERT INTO ins_table (id, v1, v2) VALUES (:id, :v1, :v2)' ) + def test_bulk_insert_inline_literal_as_sql(): context = op_fixture('postgresql', True) @@ -136,6 +150,7 @@ def test_bulk_insert_as_sql(): "INSERT INTO ins_table (id, v1, v2) VALUES (4, 'row v4', 'row v8')" ) + def test_bulk_insert_as_sql_pg(): context = _test_bulk_insert('postgresql', True) context.assert_( @@ -145,6 +160,7 @@ def test_bulk_insert_as_sql_pg(): "INSERT INTO ins_table (id, v1, v2) VALUES (4, 'row v4', 'row v8')" ) + def test_bulk_insert_as_sql_mssql(): context = _test_bulk_insert('mssql', True) # SQL server requires IDENTITY_INSERT @@ -159,12 +175,13 @@ def test_bulk_insert_as_sql_mssql(): 'SET IDENTITY_INSERT ins_table OFF' ) + def test_invalid_format(): context, t1 = _table_fixture("sqlite", False) assert_raises_message( TypeError, "List expected", - op.bulk_insert, t1, {"id":5} + op.bulk_insert, t1, {"id": 5} ) assert_raises_message( @@ -173,7 +190,9 @@ def test_invalid_format(): op.bulk_insert, t1, [(5, )] ) + class RoundTripTest(TestCase): + def setUp(self): from sqlalchemy import create_engine from alembic.migration import MigrationContext @@ -188,17 +207,18 @@ class RoundTripTest(TestCase): context = MigrationContext.configure(self.conn) self.op = op.Operations(context) self.t1 = table('foo', - column('id'), - column('data'), - column('x') - ) + column('id'), + column('data'), + column('x') + ) + def tearDown(self): self.conn.close() def test_single_insert_round_trip(self): self.op.bulk_insert(self.t1, - [{'data':"d1", "x":"x1"}] - ) + [{'data': "d1", "x": "x1"}] + ) eq_( self.conn.execute("select id, data, x from foo").fetchall(), @@ -209,9 +229,9 @@ class RoundTripTest(TestCase): def test_bulk_insert_round_trip(self): self.op.bulk_insert(self.t1, [ - {'data':"d1", "x":"x1"}, - {'data':"d2", "x":"x2"}, - {'data':"d3", "x":"x3"}, + {'data': "d1", "x": "x1"}, + {'data': "d2", "x": "x2"}, + {'data': "d3", "x": "x3"}, ]) eq_( @@ -241,4 +261,3 @@ class RoundTripTest(TestCase): (2, "d2"), ] ) - diff --git a/tests/test_command.py b/tests/test_command.py index 53a9538..b550471 100644 --- a/tests/test_command.py +++ b/tests/test_command.py @@ -7,7 +7,6 @@ from io import TextIOWrapper, BytesIO from alembic.script import ScriptDirectory - class StdoutCommandTest(unittest.TestCase): @classmethod diff --git a/tests/test_config.py b/tests/test_config.py index 6164eb9..cd56d13 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -9,6 +9,7 @@ from . import Mock, call from . import eq_, capture_db, assert_raises_message + def test_config_no_file_main_option(): cfg = config.Config() cfg.set_main_option("url", "postgresql://foo/bar") @@ -35,6 +36,7 @@ def test_standalone_op(): op.alter_column("t", "c", nullable=True) eq_(buf, ['ALTER TABLE t ALTER COLUMN c DROP NOT NULL']) + def test_no_script_error(): cfg = config.Config() assert_raises_message( @@ -72,4 +74,3 @@ class OutputEncodingTest(unittest.TestCase): stdout.mock_calls, [call.write('m?il x y'), call.write('\n')] ) - diff --git a/tests/test_environment.py b/tests/test_environment.py index cc5ccb8..ad47cf9 100644 --- a/tests/test_environment.py +++ b/tests/test_environment.py @@ -8,7 +8,9 @@ from . import Mock, call, _no_sql_testing_config, staging_env, clear_staging_env from . import eq_, is_ + class EnvironmentTest(unittest.TestCase): + def setUp(self): staging_env() self.cfg = _no_sql_testing_config() diff --git a/tests/test_mssql.py b/tests/test_mssql.py index 3205959..396692f 100644 --- a/tests/test_mssql.py +++ b/tests/test_mssql.py @@ -11,6 +11,7 @@ from . import op_fixture, capture_context_buffer, \ class FullEnvironmentTests(TestCase): + @classmethod def setup_class(cls): env = staging_env() @@ -41,13 +42,14 @@ class FullEnvironmentTests(TestCase): command.upgrade(self.cfg, self.a, sql=True) assert "BYE" in buf.getvalue() + class OpTest(TestCase): + def test_add_column(self): context = op_fixture('mssql') op.add_column('t1', Column('c1', Integer, nullable=False)) context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL") - def test_add_column_with_default(self): context = op_fixture("mssql") op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12")) @@ -78,8 +80,8 @@ class OpTest(TestCase): context = op_fixture('mssql') from sqlalchemy import Boolean op.alter_column('tests', 'col', - existing_type=Boolean(), - nullable=False) + existing_type=Boolean(), + nullable=False) context.assert_('ALTER TABLE tests ALTER COLUMN col BIT NOT NULL') def test_drop_index(self): @@ -95,7 +97,6 @@ class OpTest(TestCase): context.assert_contains("exec('alter table t1 drop constraint ' + @const_name)") context.assert_contains("ALTER TABLE t1 DROP COLUMN c1") - def test_alter_column_drop_default(self): context = op_fixture('mssql') op.alter_column("t", "c", server_default=None) @@ -186,7 +187,7 @@ class OpTest(TestCase): def test_alter_do_everything(self): context = op_fixture('mssql') op.alter_column("t", "c", new_column_name="c2", nullable=True, - type_=Integer, server_default="5") + type_=Integer, server_default="5") context.assert_( 'ALTER TABLE t ALTER COLUMN c INTEGER NULL', "ALTER TABLE t ADD DEFAULT '5' FOR c", @@ -199,7 +200,7 @@ class OpTest(TestCase): context.assert_contains("EXEC sp_rename 't1', t2") # TODO: when we add schema support - #def test_alter_column_rename_mssql_schema(self): + # def test_alter_column_rename_mssql_schema(self): # context = op_fixture('mssql') # op.alter_column("t", "c", name="x", schema="y") # context.assert_( diff --git a/tests/test_mysql.py b/tests/test_mysql.py index 16b171c..1ad1453 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -7,7 +7,9 @@ from . import op_fixture, assert_raises_message, db_for_dialect, \ staging_env, clear_staging_env from alembic.migration import MigrationContext + class MySQLOpTest(TestCase): + def test_rename_column(self): context = op_fixture('mysql') op.alter_column('t1', 'c1', new_column_name="c2", existing_type=Integer) @@ -18,7 +20,7 @@ class MySQLOpTest(TestCase): def test_rename_column_quotes_needed_one(self): context = op_fixture('mysql') op.alter_column('MyTable', 'ColumnOne', new_column_name="ColumnTwo", - existing_type=Integer) + existing_type=Integer) context.assert_( 'ALTER TABLE `MyTable` CHANGE `ColumnOne` `ColumnTwo` INTEGER NULL' ) @@ -26,7 +28,7 @@ class MySQLOpTest(TestCase): def test_rename_column_quotes_needed_two(self): context = op_fixture('mysql') op.alter_column('my table', 'column one', new_column_name="column two", - existing_type=Integer) + existing_type=Integer) context.assert_( 'ALTER TABLE `my table` CHANGE `column one` `column two` INTEGER NULL' ) @@ -34,7 +36,7 @@ class MySQLOpTest(TestCase): def test_rename_column_serv_default(self): context = op_fixture('mysql') op.alter_column('t1', 'c1', new_column_name="c2", existing_type=Integer, - existing_server_default="q") + existing_server_default="q") context.assert_( "ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL DEFAULT 'q'" ) @@ -42,7 +44,7 @@ class MySQLOpTest(TestCase): def test_rename_column_serv_compiled_default(self): context = op_fixture('mysql') op.alter_column('t1', 'c1', existing_type=Integer, - server_default=func.utc_thing(func.current_timestamp())) + server_default=func.utc_thing(func.current_timestamp())) # this is not a valid MySQL default but the point is to just # test SQL expression rendering context.assert_( @@ -52,7 +54,7 @@ class MySQLOpTest(TestCase): def test_rename_column_autoincrement(self): context = op_fixture('mysql') op.alter_column('t1', 'c1', new_column_name="c2", existing_type=Integer, - existing_autoincrement=True) + existing_autoincrement=True) context.assert_( 'ALTER TABLE t1 CHANGE c1 c2 INTEGER NULL AUTO_INCREMENT' ) @@ -60,7 +62,7 @@ class MySQLOpTest(TestCase): def test_col_add_autoincrement(self): context = op_fixture('mysql') op.alter_column('t1', 'c1', existing_type=Integer, - autoincrement=True) + autoincrement=True) context.assert_( 'ALTER TABLE t1 MODIFY c1 INTEGER NULL AUTO_INCREMENT' ) @@ -68,18 +70,17 @@ class MySQLOpTest(TestCase): def test_col_remove_autoincrement(self): context = op_fixture('mysql') op.alter_column('t1', 'c1', existing_type=Integer, - existing_autoincrement=True, - autoincrement=False) + existing_autoincrement=True, + autoincrement=False) context.assert_( 'ALTER TABLE t1 MODIFY c1 INTEGER NULL' ) - def test_col_dont_remove_server_default(self): context = op_fixture('mysql') op.alter_column('t1', 'c1', existing_type=Integer, - existing_server_default='1', - server_default=False) + existing_server_default='1', + server_default=False) context.assert_() @@ -90,8 +91,6 @@ class MySQLOpTest(TestCase): 'ALTER TABLE t ALTER COLUMN c DROP DEFAULT' ) - - def test_alter_column_modify_default(self): context = op_fixture('mysql') # notice we dont need the existing type on this one... @@ -110,7 +109,7 @@ class MySQLOpTest(TestCase): def test_col_not_nullable_existing_serv_default(self): context = op_fixture('mysql') op.alter_column('t1', 'c1', nullable=False, existing_type=Integer, - existing_server_default='5') + existing_server_default='5') context.assert_( "ALTER TABLE t1 MODIFY c1 INTEGER NOT NULL DEFAULT '5'" ) @@ -191,7 +190,9 @@ class MySQLOpTest(TestCase): op.drop_constraint, "f1", "t1" ) + class MySQLDefaultCompareTest(TestCase): + @classmethod def setup_class(cls): cls.bind = db_for_dialect("mysql") @@ -209,7 +210,7 @@ class MySQLDefaultCompareTest(TestCase): 'connection': connection, 'dialect': connection.dialect, 'context': context - } + } @classmethod def teardown_class(cls): @@ -228,11 +229,11 @@ class MySQLDefaultCompareTest(TestCase): alternate = txt expected = False t = Table("test", self.metadata, - Column("somecol", type_, server_default=text(txt) if txt else None) - ) + Column("somecol", type_, server_default=text(txt) if txt else None) + ) t2 = Table("test", MetaData(), - Column("somecol", type_, server_default=text(alternate)) - ) + Column("somecol", type_, server_default=text(alternate)) + ) assert self._compare_default( t, t2, t2.c.somecol, alternate ) is expected @@ -263,4 +264,3 @@ class MySQLDefaultCompareTest(TestCase): TIMESTAMP(), None, "CURRENT_TIMESTAMP", ) - diff --git a/tests/test_offline_environment.py b/tests/test_offline_environment.py index 7026e8c..9623bcc 100644 --- a/tests/test_offline_environment.py +++ b/tests/test_offline_environment.py @@ -9,6 +9,7 @@ from . import clear_staging_env, staging_env, \ class OfflineEnvironmentTest(TestCase): + def setUp(self): env = staging_env() self.cfg = _no_sql_testing_config() @@ -33,7 +34,6 @@ assert context.requires_connection() command.upgrade(self.cfg, a) command.downgrade(self.cfg, a) - def test_starting_rev_post_context(self): env_file_fixture(""" context.configure(dialect_name='sqlite', starting_rev='x') diff --git a/tests/test_op.py b/tests/test_op.py index eaa0d5d..8c4e964 100644 --- a/tests/test_op.py +++ b/tests/test_op.py @@ -1,7 +1,7 @@ """Test against the builders in the op.* module.""" from sqlalchemy import Integer, Column, ForeignKey, \ - Table, String, Boolean, MetaData, CheckConstraint + Table, String, Boolean, MetaData, CheckConstraint from sqlalchemy.sql import column, func, text from sqlalchemy import event @@ -9,6 +9,7 @@ from alembic import op from . import op_fixture, assert_raises_message, requires_094, eq_ from . import mock + @event.listens_for(Table, "after_parent_attach") def _add_cols(table, metadata): if table.name == "tbl_with_auto_appended_column": @@ -20,16 +21,19 @@ def test_rename_table(): op.rename_table('t1', 't2') context.assert_("ALTER TABLE t1 RENAME TO t2") + def test_rename_table_schema(): context = op_fixture() op.rename_table('t1', 't2', schema="foo") context.assert_("ALTER TABLE foo.t1 RENAME TO foo.t2") + def test_rename_table_postgresql(): context = op_fixture("postgresql") op.rename_table('t1', 't2') context.assert_("ALTER TABLE t1 RENAME TO t2") + def test_rename_table_schema_postgresql(): context = op_fixture("postgresql") op.rename_table('t1', 't2', schema="foo") @@ -76,6 +80,7 @@ def test_create_index_postgresql_expressions(): "CREATE INDEX geocoded ON locations (lower(coordinates)) " "WHERE locations.coordinates != Null") + def test_create_index_postgresql_where(): context = op_fixture("postgresql") op.create_index( @@ -84,31 +89,36 @@ def test_create_index_postgresql_where(): ['coordinates'], postgresql_where=text("locations.coordinates != Null")) context.assert_( - "CREATE INDEX geocoded ON locations (coordinates) " - "WHERE locations.coordinates != Null") + "CREATE INDEX geocoded ON locations (coordinates) " + "WHERE locations.coordinates != Null") + def test_add_column(): context = op_fixture() op.add_column('t1', Column('c1', Integer, nullable=False)) context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL") + def test_add_column_schema(): context = op_fixture() op.add_column('t1', Column('c1', Integer, nullable=False), schema="foo") context.assert_("ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL") + def test_add_column_with_default(): context = op_fixture() op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12")) context.assert_("ALTER TABLE t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL") + def test_add_column_schema_with_default(): context = op_fixture() op.add_column('t1', - Column('c1', Integer, nullable=False, server_default="12"), - schema='foo') + Column('c1', Integer, nullable=False, server_default="12"), + schema='foo') context.assert_("ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER DEFAULT '12' NOT NULL") + def test_add_column_fk(): context = op_fixture() op.add_column('t1', Column('c1', Integer, ForeignKey('c2.id'), nullable=False)) @@ -117,16 +127,18 @@ def test_add_column_fk(): "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)" ) + def test_add_column_schema_fk(): context = op_fixture() op.add_column('t1', - Column('c1', Integer, ForeignKey('c2.id'), nullable=False), - schema='foo') + Column('c1', Integer, ForeignKey('c2.id'), nullable=False), + schema='foo') context.assert_( "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL", "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES c2 (id)" ) + def test_add_column_schema_type(): """Test that a schema type generates its constraints....""" context = op_fixture() @@ -146,6 +158,7 @@ def test_add_column_schema_schema_type(): 'ALTER TABLE foo.t1 ADD CHECK (c1 IN (0, 1))' ) + def test_add_column_schema_type_checks_rule(): """Test that a schema type doesn't generate a constraint based on check rule.""" @@ -155,6 +168,7 @@ def test_add_column_schema_type_checks_rule(): 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL', ) + def test_add_column_fk_self_referential(): context = op_fixture() op.add_column('t1', Column('c1', Integer, ForeignKey('t1.c2'), nullable=False)) @@ -163,44 +177,50 @@ def test_add_column_fk_self_referential(): "ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES t1 (c2)" ) + def test_add_column_schema_fk_self_referential(): context = op_fixture() op.add_column('t1', - Column('c1', Integer, ForeignKey('foo.t1.c2'), nullable=False), - schema='foo') + Column('c1', Integer, ForeignKey('foo.t1.c2'), nullable=False), + schema='foo') context.assert_( "ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL", "ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES foo.t1 (c2)" ) + def test_add_column_fk_schema(): context = op_fixture() op.add_column('t1', Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False)) context.assert_( - 'ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL', - 'ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)' + 'ALTER TABLE t1 ADD COLUMN c1 INTEGER NOT NULL', + 'ALTER TABLE t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)' ) + def test_add_column_schema_fk_schema(): context = op_fixture() op.add_column('t1', - Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False), - schema='foo') + Column('c1', Integer, ForeignKey('remote.t2.c2'), nullable=False), + schema='foo') context.assert_( - 'ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL', - 'ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)' + 'ALTER TABLE foo.t1 ADD COLUMN c1 INTEGER NOT NULL', + 'ALTER TABLE foo.t1 ADD FOREIGN KEY(c1) REFERENCES remote.t2 (c2)' ) + def test_drop_column(): context = op_fixture() op.drop_column('t1', 'c1') context.assert_("ALTER TABLE t1 DROP COLUMN c1") + def test_drop_column_schema(): context = op_fixture() op.drop_column('t1', 'c1', schema='foo') context.assert_("ALTER TABLE foo.t1 DROP COLUMN c1") + def test_alter_column_nullable(): context = op_fixture() op.alter_column("t", "c", nullable=True) @@ -210,6 +230,7 @@ def test_alter_column_nullable(): "ALTER TABLE t ALTER COLUMN c DROP NOT NULL" ) + def test_alter_column_schema_nullable(): context = op_fixture() op.alter_column("t", "c", nullable=True, schema='foo') @@ -219,6 +240,7 @@ def test_alter_column_schema_nullable(): "ALTER TABLE foo.t ALTER COLUMN c DROP NOT NULL" ) + def test_alter_column_not_nullable(): context = op_fixture() op.alter_column("t", "c", nullable=False) @@ -228,6 +250,7 @@ def test_alter_column_not_nullable(): "ALTER TABLE t ALTER COLUMN c SET NOT NULL" ) + def test_alter_column_schema_not_nullable(): context = op_fixture() op.alter_column("t", "c", nullable=False, schema='foo') @@ -237,6 +260,7 @@ def test_alter_column_schema_not_nullable(): "ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL" ) + def test_alter_column_rename(): context = op_fixture() op.alter_column("t", "c", new_column_name="x") @@ -244,6 +268,7 @@ def test_alter_column_rename(): "ALTER TABLE t RENAME c TO x" ) + def test_alter_column_schema_rename(): context = op_fixture() op.alter_column("t", "c", new_column_name="x", schema='foo') @@ -251,6 +276,7 @@ def test_alter_column_schema_rename(): "ALTER TABLE foo.t RENAME c TO x" ) + def test_alter_column_type(): context = op_fixture() op.alter_column("t", "c", type_=String(50)) @@ -258,6 +284,7 @@ def test_alter_column_type(): 'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(50)' ) + def test_alter_column_schema_type(): context = op_fixture() op.alter_column("t", "c", type_=String(50), schema='foo') @@ -265,6 +292,7 @@ def test_alter_column_schema_type(): 'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(50)' ) + def test_alter_column_set_default(): context = op_fixture() op.alter_column("t", "c", server_default="q") @@ -272,6 +300,7 @@ def test_alter_column_set_default(): "ALTER TABLE t ALTER COLUMN c SET DEFAULT 'q'" ) + def test_alter_column_schema_set_default(): context = op_fixture() op.alter_column("t", "c", server_default="q", schema='foo') @@ -279,23 +308,26 @@ def test_alter_column_schema_set_default(): "ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT 'q'" ) + def test_alter_column_set_compiled_default(): context = op_fixture() op.alter_column("t", "c", - server_default=func.utc_thing(func.current_timestamp())) + server_default=func.utc_thing(func.current_timestamp())) context.assert_( "ALTER TABLE t ALTER COLUMN c SET DEFAULT utc_thing(CURRENT_TIMESTAMP)" ) + def test_alter_column_schema_set_compiled_default(): context = op_fixture() op.alter_column("t", "c", - server_default=func.utc_thing(func.current_timestamp()), - schema='foo') + server_default=func.utc_thing(func.current_timestamp()), + schema='foo') context.assert_( "ALTER TABLE foo.t ALTER COLUMN c SET DEFAULT utc_thing(CURRENT_TIMESTAMP)" ) + def test_alter_column_drop_default(): context = op_fixture() op.alter_column("t", "c", server_default=None) @@ -303,6 +335,7 @@ def test_alter_column_drop_default(): 'ALTER TABLE t ALTER COLUMN c DROP DEFAULT' ) + def test_alter_column_schema_drop_default(): context = op_fixture() op.alter_column("t", "c", server_default=None, schema='foo') @@ -319,6 +352,7 @@ def test_alter_column_schema_type_unnamed(): 'ALTER TABLE t ADD CHECK (c IN (0, 1))' ) + def test_alter_column_schema_schema_type_unnamed(): context = op_fixture('mssql') op.alter_column("t", "c", type_=Boolean(), schema='foo') @@ -327,6 +361,7 @@ def test_alter_column_schema_schema_type_unnamed(): 'ALTER TABLE foo.t ADD CHECK (c IN (0, 1))' ) + def test_alter_column_schema_type_named(): context = op_fixture('mssql') op.alter_column("t", "c", type_=Boolean(name="xyz")) @@ -335,6 +370,7 @@ def test_alter_column_schema_type_named(): 'ALTER TABLE t ADD CONSTRAINT xyz CHECK (c IN (0, 1))' ) + def test_alter_column_schema_schema_type_named(): context = op_fixture('mssql') op.alter_column("t", "c", type_=Boolean(name="xyz"), schema='foo') @@ -343,6 +379,7 @@ def test_alter_column_schema_schema_type_named(): 'ALTER TABLE foo.t ADD CONSTRAINT xyz CHECK (c IN (0, 1))' ) + def test_alter_column_schema_type_existing_type(): context = op_fixture('mssql') op.alter_column("t", "c", type_=String(10), existing_type=Boolean(name="xyz")) @@ -351,15 +388,17 @@ def test_alter_column_schema_type_existing_type(): 'ALTER TABLE t ALTER COLUMN c VARCHAR(10)' ) + def test_alter_column_schema_schema_type_existing_type(): context = op_fixture('mssql') op.alter_column("t", "c", type_=String(10), - existing_type=Boolean(name="xyz"), schema='foo') + existing_type=Boolean(name="xyz"), schema='foo') context.assert_( 'ALTER TABLE foo.t DROP CONSTRAINT xyz', 'ALTER TABLE foo.t ALTER COLUMN c VARCHAR(10)' ) + def test_alter_column_schema_type_existing_type_no_const(): context = op_fixture('postgresql') op.alter_column("t", "c", type_=String(10), existing_type=Boolean()) @@ -367,14 +406,16 @@ def test_alter_column_schema_type_existing_type_no_const(): 'ALTER TABLE t ALTER COLUMN c TYPE VARCHAR(10)' ) + def test_alter_column_schema_schema_type_existing_type_no_const(): context = op_fixture('postgresql') op.alter_column("t", "c", type_=String(10), existing_type=Boolean(), - schema='foo') + schema='foo') context.assert_( 'ALTER TABLE foo.t ALTER COLUMN c TYPE VARCHAR(10)' ) + def test_alter_column_schema_type_existing_type_no_new_type(): context = op_fixture('postgresql') op.alter_column("t", "c", nullable=False, existing_type=Boolean()) @@ -382,94 +423,104 @@ def test_alter_column_schema_type_existing_type_no_new_type(): 'ALTER TABLE t ALTER COLUMN c SET NOT NULL' ) + def test_alter_column_schema_schema_type_existing_type_no_new_type(): context = op_fixture('postgresql') op.alter_column("t", "c", nullable=False, existing_type=Boolean(), - schema='foo') + schema='foo') context.assert_( 'ALTER TABLE foo.t ALTER COLUMN c SET NOT NULL' ) + def test_add_foreign_key(): context = op_fixture() op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho']) + ['foo', 'bar'], ['bat', 'hoho']) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " - "REFERENCES t2 (bat, hoho)" + "REFERENCES t2 (bat, hoho)" ) + def test_add_foreign_key_schema(): context = op_fixture() op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - source_schema='foo2', referent_schema='bar2') + ['foo', 'bar'], ['bat', 'hoho'], + source_schema='foo2', referent_schema='bar2') context.assert_( "ALTER TABLE foo2.t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " - "REFERENCES bar2.t2 (bat, hoho)" + "REFERENCES bar2.t2 (bat, hoho)" ) + def test_add_foreign_key_onupdate(): context = op_fixture() op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - onupdate='CASCADE') + ['foo', 'bar'], ['bat', 'hoho'], + onupdate='CASCADE') context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " - "REFERENCES t2 (bat, hoho) ON UPDATE CASCADE" + "REFERENCES t2 (bat, hoho) ON UPDATE CASCADE" ) + def test_add_foreign_key_ondelete(): context = op_fixture() op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - ondelete='CASCADE') + ['foo', 'bar'], ['bat', 'hoho'], + ondelete='CASCADE') context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " - "REFERENCES t2 (bat, hoho) ON DELETE CASCADE" + "REFERENCES t2 (bat, hoho) ON DELETE CASCADE" ) + def test_add_foreign_key_deferrable(): context = op_fixture() op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - deferrable=True) + ['foo', 'bar'], ['bat', 'hoho'], + deferrable=True) context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " - "REFERENCES t2 (bat, hoho) DEFERRABLE" + "REFERENCES t2 (bat, hoho) DEFERRABLE" ) + def test_add_foreign_key_initially(): context = op_fixture() op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - initially='INITIAL') + ['foo', 'bar'], ['bat', 'hoho'], + initially='INITIAL') context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " - "REFERENCES t2 (bat, hoho) INITIALLY INITIAL" + "REFERENCES t2 (bat, hoho) INITIALLY INITIAL" ) + def test_add_foreign_key_match(): context = op_fixture() op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - match='SIMPLE') + ['foo', 'bar'], ['bat', 'hoho'], + match='SIMPLE') context.assert_( "ALTER TABLE t1 ADD CONSTRAINT fk_test FOREIGN KEY(foo, bar) " - "REFERENCES t2 (bat, hoho) MATCH SIMPLE" + "REFERENCES t2 (bat, hoho) MATCH SIMPLE" ) + def test_add_foreign_key_dialect_kw(): context = op_fixture() with mock.patch("alembic.operations.sa_schema.ForeignKeyConstraint") as fkc: op.create_foreign_key('fk_test', 't1', 't2', - ['foo', 'bar'], ['bat', 'hoho'], - foobar_arg='xyz') + ['foo', 'bar'], ['bat', 'hoho'], + foobar_arg='xyz') eq_(fkc.mock_calls[0], - mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'], - onupdate=None, ondelete=None, name='fk_test', - foobar_arg='xyz', - deferrable=None, initially=None, match=None)) + mock.call(['foo', 'bar'], ['t2.bat', 't2.hoho'], + onupdate=None, ondelete=None, name='fk_test', + foobar_arg='xyz', + deferrable=None, initially=None, match=None)) + def test_add_foreign_key_self_referential(): context = op_fixture() @@ -479,6 +530,7 @@ def test_add_foreign_key_self_referential(): "FOREIGN KEY(foo) REFERENCES t1 (bar)" ) + def test_add_primary_key_constraint(): context = op_fixture() op.create_primary_key("pk_test", "t1", ["foo", "bar"]) @@ -486,6 +538,7 @@ def test_add_primary_key_constraint(): "ALTER TABLE t1 ADD CONSTRAINT pk_test PRIMARY KEY (foo, bar)" ) + def test_add_primary_key_constraint_schema(): context = op_fixture() op.create_primary_key("pk_test", "t1", ["foo"], schema="bar") @@ -506,6 +559,7 @@ def test_add_check_constraint(): "CHECK (len(name) > 5)" ) + def test_add_check_constraint_schema(): context = op_fixture() op.create_check_constraint( @@ -519,6 +573,7 @@ def test_add_check_constraint_schema(): "CHECK (len(name) > 5)" ) + def test_add_unique_constraint(): context = op_fixture() op.create_unique_constraint('uk_test', 't1', ['foo', 'bar']) @@ -526,6 +581,7 @@ def test_add_unique_constraint(): "ALTER TABLE t1 ADD CONSTRAINT uk_test UNIQUE (foo, bar)" ) + def test_add_unique_constraint_schema(): context = op_fixture() op.create_unique_constraint('uk_test', 't1', ['foo', 'bar'], schema='foo') @@ -541,6 +597,7 @@ def test_drop_constraint(): "ALTER TABLE t1 DROP CONSTRAINT foo_bar_bat" ) + def test_drop_constraint_schema(): context = op_fixture() op.drop_constraint('foo_bar_bat', 't1', schema='foo') @@ -548,6 +605,7 @@ def test_drop_constraint_schema(): "ALTER TABLE foo.t1 DROP CONSTRAINT foo_bar_bat" ) + def test_create_index(): context = op_fixture() op.create_index('ik_test', 't1', ['foo', 'bar']) @@ -564,10 +622,11 @@ def test_create_index_table_col_event(): "CREATE INDEX ik_test ON tbl_with_auto_appended_column (foo, bar)" ) + def test_add_unique_constraint_col_event(): context = op_fixture() op.create_unique_constraint('ik_test', - 'tbl_with_auto_appended_column', ['foo', 'bar']) + 'tbl_with_auto_appended_column', ['foo', 'bar']) context.assert_( "ALTER TABLE tbl_with_auto_appended_column " "ADD CONSTRAINT ik_test UNIQUE (foo, bar)" @@ -581,6 +640,7 @@ def test_create_index_schema(): "CREATE INDEX ik_test ON foo.t1 (foo, bar)" ) + def test_drop_index(): context = op_fixture() op.drop_index('ik_test') @@ -588,6 +648,7 @@ def test_drop_index(): "DROP INDEX ik_test" ) + def test_drop_index_schema(): context = op_fixture() op.drop_index('ik_test', schema='foo') @@ -595,6 +656,7 @@ def test_drop_index_schema(): "DROP INDEX foo.ik_test" ) + def test_drop_table(): context = op_fixture() op.drop_table('tb_test') @@ -602,6 +664,7 @@ def test_drop_table(): "DROP TABLE tb_test" ) + def test_drop_table_schema(): context = op_fixture() op.drop_table('tb_test', schema='foo') @@ -609,6 +672,7 @@ def test_drop_table_schema(): "DROP TABLE foo.tb_test" ) + def test_create_table_selfref(): context = op_fixture() op.create_table( @@ -618,12 +682,13 @@ def test_create_table_selfref(): ) context.assert_( "CREATE TABLE some_table (" - "id INTEGER NOT NULL, " - "st_id INTEGER, " - "PRIMARY KEY (id), " - "FOREIGN KEY(st_id) REFERENCES some_table (id))" + "id INTEGER NOT NULL, " + "st_id INTEGER, " + "PRIMARY KEY (id), " + "FOREIGN KEY(st_id) REFERENCES some_table (id))" ) + def test_create_table_fk_and_schema(): context = op_fixture() op.create_table( @@ -634,12 +699,13 @@ def test_create_table_fk_and_schema(): ) context.assert_( "CREATE TABLE schema.some_table (" - "id INTEGER NOT NULL, " - "foo_id INTEGER, " - "PRIMARY KEY (id), " - "FOREIGN KEY(foo_id) REFERENCES foo (id))" + "id INTEGER NOT NULL, " + "foo_id INTEGER, " + "PRIMARY KEY (id), " + "FOREIGN KEY(foo_id) REFERENCES foo (id))" ) + def test_create_table_no_pk(): context = op_fixture() op.create_table( @@ -652,6 +718,7 @@ def test_create_table_no_pk(): "CREATE TABLE some_table (x INTEGER, y INTEGER, z INTEGER)" ) + def test_create_table_two_fk(): context = op_fixture() op.create_table( @@ -662,38 +729,40 @@ def test_create_table_two_fk(): ) context.assert_( "CREATE TABLE some_table (" - "id INTEGER NOT NULL, " - "foo_id INTEGER, " - "foo_bar INTEGER, " - "PRIMARY KEY (id), " - "FOREIGN KEY(foo_id) REFERENCES foo (id), " - "FOREIGN KEY(foo_bar) REFERENCES foo (bar))" + "id INTEGER NOT NULL, " + "foo_id INTEGER, " + "foo_bar INTEGER, " + "PRIMARY KEY (id), " + "FOREIGN KEY(foo_id) REFERENCES foo (id), " + "FOREIGN KEY(foo_bar) REFERENCES foo (bar))" ) + def test_inline_literal(): context = op_fixture() from sqlalchemy.sql import table, column from sqlalchemy import String, Integer account = table('account', - column('name', String), - column('id', Integer) - ) + column('name', String), + column('id', Integer) + ) op.execute( - account.update().\ - where(account.c.name == op.inline_literal('account 1')).\ - values({'name': op.inline_literal('account 2')}) - ) + account.update(). + where(account.c.name == op.inline_literal('account 1')). + values({'name': op.inline_literal('account 2')}) + ) op.execute( - account.update().\ - where(account.c.id == op.inline_literal(1)).\ - values({'id': op.inline_literal(2)}) - ) + account.update(). + where(account.c.id == op.inline_literal(1)). + values({'id': op.inline_literal(2)}) + ) context.assert_( "UPDATE account SET name='account 2' WHERE account.name = 'account 1'", "UPDATE account SET id=2 WHERE account.id = 1" ) + def test_cant_op(): if hasattr(op, '_proxy'): del op._proxy @@ -733,4 +802,3 @@ def test_naming_changes(): r"Unknown arguments: badarg\d, badarg\d", op.alter_column, "t", "c", badarg1="x", badarg2="y" ) - diff --git a/tests/test_op_naming_convention.py b/tests/test_op_naming_convention.py index b0b5b76..3f80ecf 100644 --- a/tests/test_op_naming_convention.py +++ b/tests/test_op_naming_convention.py @@ -1,16 +1,17 @@ from sqlalchemy import Integer, Column, ForeignKey, \ - Table, String, Boolean, MetaData, CheckConstraint + Table, String, Boolean, MetaData, CheckConstraint from sqlalchemy.sql import column, func, text from sqlalchemy import event from alembic import op from . import op_fixture, assert_raises_message, requires_094 + @requires_094 def test_add_check_constraint(): context = op_fixture(naming_convention={ - "ck": "ck_%(table_name)s_%(constraint_name)s" - }) + "ck": "ck_%(table_name)s_%(constraint_name)s" + }) op.create_check_constraint( "foo", "user_table", @@ -21,11 +22,12 @@ def test_add_check_constraint(): "CHECK (len(name) > 5)" ) + @requires_094 def test_add_check_constraint_name_is_none(): context = op_fixture(naming_convention={ - "ck": "ck_%(table_name)s_foo" - }) + "ck": "ck_%(table_name)s_foo" + }) op.create_check_constraint( None, "user_table", @@ -36,11 +38,12 @@ def test_add_check_constraint_name_is_none(): "CHECK (len(name) > 5)" ) + @requires_094 def test_add_unique_constraint_name_is_none(): context = op_fixture(naming_convention={ - "uq": "uq_%(table_name)s_foo" - }) + "uq": "uq_%(table_name)s_foo" + }) op.create_unique_constraint( None, "user_table", @@ -54,8 +57,8 @@ def test_add_unique_constraint_name_is_none(): @requires_094 def test_add_index_name_is_none(): context = op_fixture(naming_convention={ - "ix": "ix_%(table_name)s_foo" - }) + "ix": "ix_%(table_name)s_foo" + }) op.create_index( None, "user_table", @@ -66,7 +69,6 @@ def test_add_index_name_is_none(): ) - @requires_094 def test_add_check_constraint_already_named_from_schema(): m1 = MetaData(naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) @@ -74,7 +76,7 @@ def test_add_check_constraint_already_named_from_schema(): Table('t', m1, Column('x'), ck) context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) op.create_table( "some_table", @@ -85,10 +87,11 @@ def test_add_check_constraint_already_named_from_schema(): "(x INTEGER CONSTRAINT ck_t_cc1 CHECK (im a constraint))" ) + @requires_094 def test_add_check_constraint_inline_on_table(): context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) op.create_table( "some_table", Column('x', Integer), @@ -99,10 +102,11 @@ def test_add_check_constraint_inline_on_table(): "(x INTEGER, CONSTRAINT ck_some_table_cc1 CHECK (im a constraint))" ) + @requires_094 def test_add_check_constraint_inline_on_table_w_f(): context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) op.create_table( "some_table", Column('x', Integer), @@ -113,10 +117,11 @@ def test_add_check_constraint_inline_on_table_w_f(): "(x INTEGER, CONSTRAINT ck_some_table_cc1 CHECK (im a constraint))" ) + @requires_094 def test_add_check_constraint_inline_on_column(): context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) op.create_table( "some_table", Column('x', Integer, CheckConstraint("im a constraint", name="cc1")) @@ -126,10 +131,11 @@ def test_add_check_constraint_inline_on_column(): "(x INTEGER CONSTRAINT ck_some_table_cc1 CHECK (im a constraint))" ) + @requires_094 def test_add_check_constraint_inline_on_column_w_f(): context = op_fixture( - naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) + naming_convention={"ck": "ck_%(table_name)s_%(constraint_name)s"}) op.create_table( "some_table", Column('x', Integer, CheckConstraint("im a constraint", name=op.f("ck_q_cc1"))) @@ -143,8 +149,8 @@ def test_add_check_constraint_inline_on_column_w_f(): @requires_094 def test_add_column_schema_type(): context = op_fixture(naming_convention={ - "ck": "ck_%(table_name)s_%(constraint_name)s" - }) + "ck": "ck_%(table_name)s_%(constraint_name)s" + }) op.add_column('t1', Column('c1', Boolean(name='foo'), nullable=False)) context.assert_( 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL', @@ -155,12 +161,10 @@ def test_add_column_schema_type(): @requires_094 def test_add_column_schema_type_w_f(): context = op_fixture(naming_convention={ - "ck": "ck_%(table_name)s_%(constraint_name)s" - }) + "ck": "ck_%(table_name)s_%(constraint_name)s" + }) op.add_column('t1', Column('c1', Boolean(name=op.f('foo')), nullable=False)) context.assert_( 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN NOT NULL', 'ALTER TABLE t1 ADD CONSTRAINT foo CHECK (c1 IN (0, 1))' ) - - diff --git a/tests/test_oracle.py b/tests/test_oracle.py index d443a71..781a1ab 100644 --- a/tests/test_oracle.py +++ b/tests/test_oracle.py @@ -11,6 +11,7 @@ from . import op_fixture, capture_context_buffer, \ class FullEnvironmentTests(TestCase): + @classmethod def setup_class(cls): env = staging_env() @@ -40,13 +41,14 @@ class FullEnvironmentTests(TestCase): command.upgrade(self.cfg, self.a, sql=True) assert "BYE" in buf.getvalue() + class OpTest(TestCase): + def test_add_column(self): context = op_fixture('oracle') op.add_column('t1', Column('c1', Integer, nullable=False)) context.assert_("ALTER TABLE t1 ADD c1 INTEGER NOT NULL") - def test_add_column_with_default(self): context = op_fixture("oracle") op.add_column('t1', Column('c1', Integer, nullable=False, server_default="12")) @@ -147,10 +149,9 @@ class OpTest(TestCase): ) # TODO: when we add schema support - #def test_alter_column_rename_oracle_schema(self): + # def test_alter_column_rename_oracle_schema(self): # context = op_fixture('oracle') # op.alter_column("t", "c", name="x", schema="y") # context.assert_( # 'ALTER TABLE y.t RENAME COLUMN c TO c2' # ) - diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 2e0965e..4cd160b 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -1,7 +1,7 @@ from unittest import TestCase from sqlalchemy import DateTime, MetaData, Table, Column, text, Integer, \ - String, Interval + String, Interval from sqlalchemy.dialects.postgresql import ARRAY from sqlalchemy.schema import DefaultClause from sqlalchemy.engine.reflection import Inspector @@ -13,10 +13,12 @@ from alembic import command, util from alembic.migration import MigrationContext from alembic.script import ScriptDirectory from . import db_for_dialect, eq_, staging_env, \ - clear_staging_env, _no_sql_testing_config,\ - capture_context_buffer, requires_09, write_script + clear_staging_env, _no_sql_testing_config,\ + capture_context_buffer, requires_09, write_script + class PGOfflineEnumTest(TestCase): + def setUp(self): staging_env() self.cfg = cfg = _no_sql_testing_config() @@ -29,7 +31,6 @@ class PGOfflineEnumTest(TestCase): def tearDown(self): clear_staging_env() - def _inline_enum_script(self): write_script(self.script, self.rid, """ revision = '%s' @@ -103,6 +104,7 @@ def downgrade(): class PostgresqlInlineLiteralTest(TestCase): + @classmethod def setup_class(cls): cls.bind = db_for_dialect("postgresql") @@ -144,7 +146,9 @@ class PostgresqlInlineLiteralTest(TestCase): 1, ) + class PostgresqlDefaultCompareTest(TestCase): + @classmethod def setup_class(cls): cls.bind = db_for_dialect("postgresql") @@ -180,19 +184,19 @@ class PostgresqlDefaultCompareTest(TestCase): alternate = orig_default t1 = Table("test", self.metadata, - Column("somecol", type_, server_default=orig_default)) + Column("somecol", type_, server_default=orig_default)) t2 = Table("test", MetaData(), - Column("somecol", type_, server_default=alternate)) + Column("somecol", type_, server_default=alternate)) t1.create(self.bind) insp = Inspector.from_engine(self.bind) cols = insp.get_columns(t1.name) insp_col = Column("somecol", cols[0]['type'], - server_default=text(cols[0]['default'])) + server_default=text(cols[0]['default'])) diffs = [] _compare_server_default(None, "test", "somecol", insp_col, - t2.c.somecol, diffs, self.autogen_context) + t2.c.somecol, diffs, self.autogen_context) eq_(bool(diffs), diff_expected) def _compare_default( @@ -284,11 +288,11 @@ class PostgresqlDefaultCompareTest(TestCase): def test_primary_key_skip(self): """Test that SERIAL cols are just skipped""" t1 = Table("sometable", self.metadata, - Column("id", Integer, primary_key=True) - ) + Column("id", Integer, primary_key=True) + ) t2 = Table("sometable", MetaData(), - Column("id", Integer, primary_key=True) - ) + Column("id", Integer, primary_key=True) + ) assert not self._compare_default( t1, t2, t2.c.id, "" ) diff --git a/tests/test_revision_create.py b/tests/test_revision_create.py index 5bf12cf..cbe2a6e 100644 --- a/tests/test_revision_create.py +++ b/tests/test_revision_create.py @@ -10,7 +10,9 @@ import datetime env, abc, def_ = None, None, None + class GeneralOrderedTests(unittest.TestCase): + def test_001_environment(self): assert_set = set(['env.py', 'script.py.mako', 'README']) eq_( @@ -76,28 +78,26 @@ class GeneralOrderedTests(unittest.TestCase): def test_008_long_name(self): rid = util.rev_id() env.generate_revision(rid, - "this is a really long name with " - "lots of characters and also " - "I'd like it to\nhave\nnewlines") + "this is a really long name with " + "lots of characters and also " + "I'd like it to\nhave\nnewlines") assert os.access( os.path.join(env.dir, 'versions', - '%s_this_is_a_really_long_name_with_lots_of_.py' % rid), - os.F_OK) - + '%s_this_is_a_really_long_name_with_lots_of_.py' % rid), + os.F_OK) def test_009_long_name_configurable(self): env.truncate_slug_length = 60 rid = util.rev_id() env.generate_revision(rid, - "this is a really long name with " - "lots of characters and also " - "I'd like it to\nhave\nnewlines") + "this is a really long name with " + "lots of characters and also " + "I'd like it to\nhave\nnewlines") assert os.access( os.path.join(env.dir, 'versions', - '%s_this_is_a_really_long_name_with_lots_' - 'of_characters_and_also_.py' % rid), - os.F_OK) - + '%s_this_is_a_really_long_name_with_lots_' + 'of_characters_and_also_.py' % rid), + os.F_OK) @classmethod def setup_class(cls): @@ -108,7 +108,9 @@ class GeneralOrderedTests(unittest.TestCase): def teardown_class(cls): clear_staging_env() + class ScriptNamingTest(unittest.TestCase): + @classmethod def setup_class(cls): _testing_config() @@ -119,12 +121,12 @@ class ScriptNamingTest(unittest.TestCase): def test_args(self): script = ScriptDirectory( - staging_directory, - file_template="%(rev)s_%(slug)s_" - "%(year)s_%(month)s_" - "%(day)s_%(hour)s_" - "%(minute)s_%(second)s" - ) + staging_directory, + file_template="%(rev)s_%(slug)s_" + "%(year)s_%(month)s_" + "%(day)s_%(hour)s_" + "%(minute)s_%(second)s" + ) create_date = datetime.datetime(2012, 7, 25, 15, 8, 5) eq_( script._rev_path("12345", "this is a message", create_date), @@ -134,6 +136,7 @@ class ScriptNamingTest(unittest.TestCase): class TemplateArgsTest(unittest.TestCase): + def setUp(self): staging_env() self.cfg = _no_sql_testing_config( @@ -153,7 +156,7 @@ class TemplateArgsTest(unittest.TestCase): template_args=template_args ) env.configure(dialect_name="sqlite", - template_args={"y": "y2", "q": "q1"}) + template_args={"y": "y2", "q": "q1"}) eq_( template_args, {"x": "x1", "y": "y2", "z": "z1", "q": "q1"} @@ -206,4 +209,3 @@ down_revision = ${repr(down_revision)} with open(rev.path) as f: text = f.read() assert "somearg: somevalue" in text - diff --git a/tests/test_revision_paths.py b/tests/test_revision_paths.py index 5a02189..15da250 100644 --- a/tests/test_revision_paths.py +++ b/tests/test_revision_paths.py @@ -6,6 +6,7 @@ env = None a, b, c, d, e = None, None, None, None, None cfg = None + def setup(): global env env = staging_env() @@ -16,6 +17,7 @@ def setup(): d = env.generate_revision(util.rev_id(), 'c->d', refresh=True) e = env.generate_revision(util.rev_id(), 'd->e', refresh=True) + def teardown(): clear_staging_env() @@ -39,6 +41,7 @@ def test_upgrade_path(): ] ) + def test_relative_upgrade_path(): eq_( env._upgrade_revs("+2", a.revision), @@ -64,6 +67,7 @@ def test_relative_upgrade_path(): ] ) + def test_invalid_relative_upgrade_path(): assert_raises_message( util.CommandError, @@ -77,6 +81,7 @@ def test_invalid_relative_upgrade_path(): env._upgrade_revs, "+5", b.revision ) + def test_downgrade_path(): eq_( @@ -96,6 +101,7 @@ def test_downgrade_path(): ] ) + def test_relative_downgrade_path(): eq_( env._downgrade_revs("-1", c.revision), @@ -113,6 +119,7 @@ def test_relative_downgrade_path(): ] ) + def test_invalid_relative_downgrade_path(): assert_raises_message( util.CommandError, @@ -126,6 +133,7 @@ def test_invalid_relative_downgrade_path(): env._downgrade_revs, "+2", b.revision ) + def test_invalid_move_rev_to_none(): assert_raises_message( util.CommandError, @@ -133,10 +141,10 @@ def test_invalid_move_rev_to_none(): env._downgrade_revs, b.revision[0:3], None ) + def test_invalid_move_higher_to_lower(): assert_raises_message( - util.CommandError, + util.CommandError, "Revision %s is not an ancestor of %s" % (c.revision, b.revision), env._downgrade_revs, c.revision[0:4], b.revision ) - diff --git a/tests/test_sql_script.py b/tests/test_sql_script.py index 7aae797..ba64df7 100644 --- a/tests/test_sql_script.py +++ b/tests/test_sql_script.py @@ -14,6 +14,7 @@ import re cfg = None a, b, c = None, None, None + class ThreeRevTest(unittest.TestCase): def setUp(self): @@ -32,11 +33,11 @@ class ThreeRevTest(unittest.TestCase): with capture_context_buffer(transactional_ddl=True) as buf: command.upgrade(cfg, c, sql=True) assert re.match( - (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % a) + - (r".*%s" % b) + - (r".*%s.*?COMMIT;.*$" % c), + (r"^BEGIN;\s+CREATE TABLE.*?%s.*" % a) + + (r".*%s" % b) + + (r".*%s.*?COMMIT;.*$" % c), - buf.getvalue(), re.S) + buf.getvalue(), re.S) def test_begin_commit_nontransactional_ddl(self): with capture_context_buffer(transactional_ddl=False) as buf: @@ -48,11 +49,11 @@ class ThreeRevTest(unittest.TestCase): with capture_context_buffer(transaction_per_migration=True) as buf: command.upgrade(cfg, c, sql=True) assert re.match( - (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % a) + - (r"BEGIN;.*?%s.*?COMMIT;.*" % b) + - (r"BEGIN;.*?%s.*?COMMIT;.*$" % c), + (r"^BEGIN;\s+CREATE TABLE.*%s.*?COMMIT;.*" % a) + + (r"BEGIN;.*?%s.*?COMMIT;.*" % b) + + (r"BEGIN;.*?%s.*?COMMIT;.*$" % c), - buf.getvalue(), re.S) + buf.getvalue(), re.S) def test_version_from_none_insert(self): with capture_context_buffer() as buf: @@ -99,6 +100,7 @@ class ThreeRevTest(unittest.TestCase): class EncodingTest(unittest.TestCase): + def setUp(self): global cfg, env, a env = staging_env() @@ -128,8 +130,8 @@ def downgrade(): def test_encode(self): with capture_context_buffer( - bytes_io=True, - output_encoding='utf-8' - ) as buf: + bytes_io=True, + output_encoding='utf-8' + ) as buf: command.upgrade(cfg, a, sql=True) assert "« S’il vous plaît…".encode("utf-8") in buf.getvalue() diff --git a/tests/test_sqlite.py b/tests/test_sqlite.py index 9ceb78e..ea9411d 100644 --- a/tests/test_sqlite.py +++ b/tests/test_sqlite.py @@ -1,8 +1,9 @@ from tests import op_fixture, assert_raises_message from alembic import op -from sqlalchemy import Integer, Column, Boolean +from sqlalchemy import Integer, Column, Boolean from sqlalchemy.sql import column + def test_add_column(): context = op_fixture('sqlite') op.add_column('t1', Column('c1', Integer)) @@ -10,6 +11,7 @@ def test_add_column(): 'ALTER TABLE t1 ADD COLUMN c1 INTEGER' ) + def test_add_column_implicit_constraint(): context = op_fixture('sqlite') op.add_column('t1', Column('c1', Boolean)) @@ -17,6 +19,7 @@ def test_add_column_implicit_constraint(): 'ALTER TABLE t1 ADD COLUMN c1 BOOLEAN' ) + def test_add_explicit_constraint(): context = op_fixture('sqlite') assert_raises_message( @@ -28,6 +31,7 @@ def test_add_explicit_constraint(): column('name') > 5 ) + def test_drop_explicit_constraint(): context = op_fixture('sqlite') assert_raises_message( @@ -37,4 +41,3 @@ def test_drop_explicit_constraint(): "foo", "sometable", ) - diff --git a/tests/test_version_table.py b/tests/test_version_table.py index 3a0a54d..98dec50 100644 --- a/tests/test_version_table.py +++ b/tests/test_version_table.py @@ -8,6 +8,7 @@ from alembic.util import CommandError version_table = Table('version_table', MetaData(), Column('version_num', String(32), nullable=False)) + class TestMigrationContext(unittest.TestCase): _bind = [] diff --git a/tests/test_versioning.py b/tests/test_versioning.py index 68440fc..7c59a12 100644 --- a/tests/test_versioning.py +++ b/tests/test_versioning.py @@ -7,6 +7,7 @@ from . import clear_staging_env, staging_env, \ _sqlite_testing_config, sqlite_db, eq_, write_script, \ assert_raises_message + class VersioningTest(unittest.TestCase): sourceless = False @@ -62,7 +63,6 @@ class VersioningTest(unittest.TestCase): """ % (c, b), sourceless=self.sourceless) - def test_002_upgrade(self): command.upgrade(self.cfg, c) db = sqlite_db() @@ -94,7 +94,6 @@ class VersioningTest(unittest.TestCase): def test_006_upgrade_again(self): command.upgrade(self.cfg, b) - # TODO: test some invalid movements @classmethod @@ -106,7 +105,9 @@ class VersioningTest(unittest.TestCase): def teardown_class(cls): clear_staging_env() + class VersionNameTemplateTest(unittest.TestCase): + def setUp(self): self.env = staging_env() self.cfg = _sqlite_testing_config() @@ -188,7 +189,9 @@ class VersionNameTemplateTest(unittest.TestCase): class SourcelessVersioningTest(VersioningTest): sourceless = True + class SourcelessNeedsFlagTest(unittest.TestCase): + def setUp(self): self.env = staging_env(sourceless=False) self.cfg = _sqlite_testing_config() |