diff options
-rw-r--r-- | docs/changelog.rst | 1 | ||||
-rw-r--r-- | migrate/changeset/__init__.py | 12 | ||||
-rw-r--r-- | migrate/changeset/ansisql.py | 165 | ||||
-rw-r--r-- | migrate/changeset/constraint.py | 18 | ||||
-rw-r--r-- | migrate/changeset/databases/firebird.py | 5 | ||||
-rw-r--r-- | migrate/changeset/databases/mysql.py | 66 | ||||
-rw-r--r-- | migrate/changeset/databases/oracle.py | 9 | ||||
-rw-r--r-- | migrate/changeset/databases/postgres.py | 9 | ||||
-rw-r--r-- | migrate/changeset/databases/sqlite.py | 11 | ||||
-rw-r--r-- | migrate/changeset/databases/visitor.py | 22 | ||||
-rw-r--r-- | migrate/changeset/schema.py | 10 | ||||
-rw-r--r-- | migrate/versioning/genmodel.py | 23 | ||||
-rw-r--r-- | migrate/versioning/schemadiff.py | 33 | ||||
-rw-r--r-- | migrate/versioning/script/py.py | 2 | ||||
-rw-r--r-- | test/versioning/test_shell.py | 20 |
15 files changed, 252 insertions, 154 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst index 58e6ec3..89a8238 100644 --- a/docs/changelog.rst +++ b/docs/changelog.rst @@ -1,6 +1,7 @@ 0.5.5 ----- +- added support for SQLAlchemy 0.6 (missing oracle and firebird) by Michael Bayer - alter, create, drop column / rename table / rename index constructs now accept `alter_metadata` parameter. If True, it will modify Column/Table objects according to changes. Otherwise, everything will be untouched. - complete refactoring of :class:`~migrate.changeset.schema.ColumnDelta` (fixes issue 23) - added support for :ref:`firebird <firebird-d>` diff --git a/migrate/changeset/__init__.py b/migrate/changeset/__init__.py index 940c23f..25fc73b 100644 --- a/migrate/changeset/__init__.py +++ b/migrate/changeset/__init__.py @@ -6,9 +6,21 @@ """ import sqlalchemy +from sqlalchemy import __version__ as _sa_version +import re + +_sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split(".")) +SQLA_06 = _sa_version >= (0, 6) + +del re +del _sa_version + from migrate.changeset.schema import * from migrate.changeset.constraint import * + + + sqlalchemy.schema.Table.__bases__ += (ChangesetTable, ) sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, ) sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, ) diff --git a/migrate/changeset/ansisql.py b/migrate/changeset/ansisql.py index f0dfed5..ee585a7 100644 --- a/migrate/changeset/ansisql.py +++ b/migrate/changeset/ansisql.py @@ -5,23 +5,53 @@ things that just happen to work with multiple databases. """ import sqlalchemy as sa +from sqlalchemy.schema import SchemaVisitor from sqlalchemy.engine.default import DefaultDialect from sqlalchemy.schema import (ForeignKeyConstraint, PrimaryKeyConstraint, CheckConstraint, UniqueConstraint, Index) -from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper -from migrate.changeset import exceptions, constraint +from migrate.changeset import exceptions, constraint, SQLA_06 +import StringIO +if not SQLA_06: + from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper +else: + from sqlalchemy.schema import AddConstraint, DropConstraint + from sqlalchemy.sql.compiler import DDLCompiler + SchemaGenerator = SchemaDropper = DDLCompiler -SchemaIterator = sa.engine.SchemaIterator +class AlterTableVisitor(SchemaVisitor): + """Common operations for ``ALTER TABLE`` statements.""" + def append(self, s): + """Append content to the SchemaIterator's query buffer.""" -class AlterTableVisitor(SchemaIterator): - """Common operations for ``ALTER TABLE`` statements.""" + self.buffer.write(s) + + def execute(self): + """Execute the contents of the SchemaIterator's buffer.""" + + try: + return self.connection.execute(self.buffer.getvalue()) + finally: + self.buffer.truncate(0) + + def __init__(self, dialect, connection, **kw): + self.connection = connection + self.buffer = StringIO.StringIO() + self.preparer = dialect.identifier_preparer + self.dialect = dialect + def traverse_single(self, elem): + ret = super(AlterTableVisitor, self).traverse_single(elem) + if ret: + # adapt to 0.6 which uses a string-returning + # object + self.append(ret) + def _to_table(self, param): """Returns the table object for the given param object.""" if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)): @@ -88,6 +118,9 @@ class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator): name=column.primary_key_name) cons.create() + if SQLA_06: + def add_foreignkey(self, fk): + self.connection.execute(AddConstraint(fk)) class ANSIColumnDropper(AlterTableVisitor, SchemaDropper): """Extends ANSI SQL dropper for column dropping (``ALTER TABLE @@ -181,7 +214,10 @@ class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator): def _visit_column_type(self, table, column, delta): type_ = delta['type'] - type_text = type_.dialect_impl(self.dialect).get_col_spec() + if SQLA_06: + type_text = str(type_.compile(dialect=self.dialect)) + else: + type_text = type_.dialect_impl(self.dialect).get_col_spec() self.append("TYPE %s" % type_text) def _visit_column_name(self, table, column, delta): @@ -225,60 +261,75 @@ class ANSIConstraintCommon(AlterTableVisitor): def visit_migrate_unique_constraint(self, *p, **k): self._visit_constraint(*p, **k) +if SQLA_06: + class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator): + def _visit_constraint(self, constraint): + constraint.name = self.get_constraint_name(constraint) + self.append(self.process(AddConstraint(constraint))) + self.execute() -class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator): + class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper): + def _visit_constraint(self, constraint): + constraint.name = self.get_constraint_name(constraint) + self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade))) + self.execute() - def get_constraint_specification(self, cons, **kwargs): - """Constaint SQL generators. - - We cannot use SA visitors because they append comma. - """ - if isinstance(cons, PrimaryKeyConstraint): - if cons.name is not None: - self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons)) - self.append("PRIMARY KEY ") - self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) - for c in cons)) - self.define_constraint_deferrability(cons) - elif isinstance(cons, ForeignKeyConstraint): - self.define_foreign_key(cons) - elif isinstance(cons, CheckConstraint): - if cons.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(cons)) - self.append("CHECK (%s)" % cons.sqltext) - self.define_constraint_deferrability(cons) - elif isinstance(cons, UniqueConstraint): - if cons.name is not None: - self.append("CONSTRAINT %s " % - self.preparer.format_constraint(cons)) - self.append("UNIQUE (%s)" % \ - (', '.join(self.preparer.quote(c.name, c.quote) for c in cons))) - self.define_constraint_deferrability(cons) - else: - raise exceptions.InvalidConstraintError(cons) +else: + class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator): - def _visit_constraint(self, constraint): - table = self.start_alter_table(constraint) - constraint.name = self.get_constraint_name(constraint) - self.append("ADD ") - self.get_constraint_specification(constraint) - self.execute() - - -class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper): - - def _visit_constraint(self, constraint): - self.start_alter_table(constraint) - self.append("DROP CONSTRAINT ") - constraint.name = self.get_constraint_name(constraint) - self.append(self.preparer.format_constraint(constraint)) - if constraint.cascade: - self.cascade_constraint(constraint) - self.execute() - - def cascade_constraint(self, constraint): - self.append(" CASCADE") + def get_constraint_specification(self, cons, **kwargs): + """Constaint SQL generators. + + We cannot use SA visitors because they append comma. + """ + + if isinstance(cons, PrimaryKeyConstraint): + if cons.name is not None: + self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons)) + self.append("PRIMARY KEY ") + self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote) + for c in cons)) + self.define_constraint_deferrability(cons) + elif isinstance(cons, ForeignKeyConstraint): + self.define_foreign_key(cons) + elif isinstance(cons, CheckConstraint): + if cons.name is not None: + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(cons)) + self.append("CHECK (%s)" % cons.sqltext) + self.define_constraint_deferrability(cons) + elif isinstance(cons, UniqueConstraint): + if cons.name is not None: + self.append("CONSTRAINT %s " % + self.preparer.format_constraint(cons)) + self.append("UNIQUE (%s)" % \ + (', '.join(self.preparer.quote(c.name, c.quote) for c in cons))) + self.define_constraint_deferrability(cons) + else: + raise exceptions.InvalidConstraintError(cons) + + def _visit_constraint(self, constraint): + + table = self.start_alter_table(constraint) + constraint.name = self.get_constraint_name(constraint) + self.append("ADD ") + self.get_constraint_specification(constraint) + self.execute() + + + class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper): + + def _visit_constraint(self, constraint): + self.start_alter_table(constraint) + self.append("DROP CONSTRAINT ") + constraint.name = self.get_constraint_name(constraint) + self.append(self.preparer.format_constraint(constraint)) + if constraint.cascade: + self.cascade_constraint(constraint) + self.execute() + + def cascade_constraint(self, constraint): + self.append(" CASCADE") class ANSIDialect(DefaultDialect): diff --git a/migrate/changeset/constraint.py b/migrate/changeset/constraint.py index b09b30a..72251f5 100644 --- a/migrate/changeset/constraint.py +++ b/migrate/changeset/constraint.py @@ -4,7 +4,7 @@ from sqlalchemy import schema from migrate.changeset.exceptions import * - +from migrate.changeset import SQLA_06 class ConstraintChangeset(object): """Base class for Constraint classes.""" @@ -54,7 +54,10 @@ class ConstraintChangeset(object): """ self.cascade = kw.pop('cascade', False) self.__do_imports('constraintdropper', *a, **kw) - self.columns.clear() + # the spirit of Constraint objects is that they + # are immutable (just like in a DB. they're only ADDed + # or DROPped). + #self.columns.clear() return self @@ -69,7 +72,7 @@ class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint): :type cols: strings or Column instances """ - __visit_name__ = 'migrate_primary_key_constraint' + __migrate_visit_name__ = 'migrate_primary_key_constraint' def __init__(self, *cols, **kwargs): colnames, table = self._normalize_columns(cols) @@ -97,7 +100,7 @@ class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint): :type refcolumns: list of strings or Column instances """ - __visit_name__ = 'migrate_foreign_key_constraint' + __migrate_visit_name__ = 'migrate_foreign_key_constraint' def __init__(self, columns, refcolumns, *args, **kwargs): colnames, table = self._normalize_columns(columns) @@ -139,7 +142,7 @@ class CheckConstraint(ConstraintChangeset, schema.CheckConstraint): :type sqltext: string """ - __visit_name__ = 'migrate_check_constraint' + __migrate_visit_name__ = 'migrate_check_constraint' def __init__(self, sqltext, *args, **kwargs): cols = kwargs.pop('columns', []) @@ -150,7 +153,8 @@ class CheckConstraint(ConstraintChangeset, schema.CheckConstraint): table = kwargs.pop('table', table) schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs) if table is not None: - self.table = table + if not SQLA_06: + self.table = table self._set_parent(table) self.colnames = colnames @@ -172,7 +176,7 @@ class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint): .. versionadded:: 0.5.5 """ - __visit_name__ = 'migrate_unique_constraint' + __migrate_visit_name__ = 'migrate_unique_constraint' def __init__(self, *cols, **kwargs): self.colnames, table = self._normalize_columns(cols) diff --git a/migrate/changeset/databases/firebird.py b/migrate/changeset/databases/firebird.py index d60cf00..5eacd58 100644 --- a/migrate/changeset/databases/firebird.py +++ b/migrate/changeset/databases/firebird.py @@ -1,14 +1,13 @@ """ Firebird database specific implementations of changeset classes. """ -from sqlalchemy.databases import firebird as sa_base from migrate.changeset import ansisql, exceptions - +# TODO: SQLA 0.6 has not migrated the FB dialect over yet +from sqlalchemy.databases import firebird as sa_base FBSchemaGenerator = sa_base.FBSchemaGenerator - class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator): """Firebird column generator implementation.""" diff --git a/migrate/changeset/databases/mysql.py b/migrate/changeset/databases/mysql.py index 5b5a16e..6655a42 100644 --- a/migrate/changeset/databases/mysql.py +++ b/migrate/changeset/databases/mysql.py @@ -2,13 +2,13 @@ MySQL database specific implementations of changeset classes. """ +from migrate.changeset import ansisql, exceptions, SQLA_06 from sqlalchemy.databases import mysql as sa_base -from migrate.changeset import ansisql, exceptions - - -MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator - +if not SQLA_06: + MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator +else: + MySQLSchemaGenerator = sa_base.MySQLDDLCompiler class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator): pass @@ -39,31 +39,37 @@ class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger): class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator): pass - -class MySQLConstraintDropper(ansisql.ANSIConstraintDropper): - - def visit_migrate_primary_key_constraint(self, constraint): - self.start_alter_table(constraint) - self.append("DROP PRIMARY KEY") - self.execute() - - def visit_migrate_foreign_key_constraint(self, constraint): - self.start_alter_table(constraint) - self.append("DROP FOREIGN KEY ") - constraint.name = self.get_constraint_name(constraint) - self.append(self.preparer.format_constraint(constraint)) - self.execute() - - def visit_migrate_check_constraint(self, *p, **k): - raise exceptions.NotSupportedError("MySQL does not support CHECK" - " constraints, use triggers instead.") - - def visit_migrate_unique_constraint(self, constraint, *p, **k): - self.start_alter_table(constraint) - self.append('DROP INDEX ') - constraint.name = self.get_constraint_name(constraint) - self.append(self.preparer.format_constraint(constraint)) - self.execute() +if SQLA_06: + class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper): + def visit_migrate_check_constraint(self, *p, **k): + raise exceptions.NotSupportedError("MySQL does not support CHECK" + " constraints, use triggers instead.") + +else: + class MySQLConstraintDropper(ansisql.ANSIConstraintDropper): + + def visit_migrate_primary_key_constraint(self, constraint): + self.start_alter_table(constraint) + self.append("DROP PRIMARY KEY") + self.execute() + + def visit_migrate_foreign_key_constraint(self, constraint): + self.start_alter_table(constraint) + self.append("DROP FOREIGN KEY ") + constraint.name = self.get_constraint_name(constraint) + self.append(self.preparer.format_constraint(constraint)) + self.execute() + + def visit_migrate_check_constraint(self, *p, **k): + raise exceptions.NotSupportedError("MySQL does not support CHECK" + " constraints, use triggers instead.") + + def visit_migrate_unique_constraint(self, constraint, *p, **k): + self.start_alter_table(constraint) + self.append('DROP INDEX ') + constraint.name = self.get_constraint_name(constraint) + self.append(self.preparer.format_constraint(constraint)) + self.execute() class MySQLDialect(ansisql.ANSIDialect): diff --git a/migrate/changeset/databases/oracle.py b/migrate/changeset/databases/oracle.py index 93c9f8f..fd2749a 100644 --- a/migrate/changeset/databases/oracle.py +++ b/migrate/changeset/databases/oracle.py @@ -2,12 +2,17 @@ Oracle database specific implementations of changeset classes. """ import sqlalchemy as sa -from sqlalchemy.databases import oracle as sa_base from migrate.changeset import ansisql, exceptions +from sqlalchemy.databases import oracle as sa_base + +from migrate.changeset import ansisql, exceptions, SQLA_06 -OracleSchemaGenerator = sa_base.OracleSchemaGenerator +if not SQLA_06: + OracleSchemaGenerator = sa_base.OracleSchemaGenerator +else: + OracleSchemaGenerator = sa_base.OracleDDLCompiler class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator): diff --git a/migrate/changeset/databases/postgres.py b/migrate/changeset/databases/postgres.py index bcdc08b..2c36ed1 100644 --- a/migrate/changeset/databases/postgres.py +++ b/migrate/changeset/databases/postgres.py @@ -3,12 +3,13 @@ .. _`PostgreSQL`: http://www.postgresql.org/ """ -from migrate.changeset import ansisql +from migrate.changeset import ansisql, SQLA_06 from sqlalchemy.databases import postgres as sa_base -#import sqlalchemy as sa - -PGSchemaGenerator = sa_base.PGSchemaGenerator +if not SQLA_06: + PGSchemaGenerator = sa_base.PGSchemaGenerator +else: + PGSchemaGenerator = sa_base.PGDDLCompiler class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator): diff --git a/migrate/changeset/databases/sqlite.py b/migrate/changeset/databases/sqlite.py index 59902b4..64be9bf 100644 --- a/migrate/changeset/databases/sqlite.py +++ b/migrate/changeset/databases/sqlite.py @@ -8,10 +8,12 @@ from copy import copy from sqlalchemy.databases import sqlite as sa_base -from migrate.changeset import ansisql, exceptions +from migrate.changeset import ansisql, exceptions, SQLA_06 - -SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator +if not SQLA_06: + SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator +else: + SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler class SQLiteCommon(object): @@ -52,8 +54,7 @@ class SQLiteHelper(SQLiteCommon): table.indexes = ixbackup table.constraints = consbackup - -class SQLiteColumnGenerator(SQLiteSchemaGenerator, SQLiteCommon, +class SQLiteColumnGenerator(SQLiteSchemaGenerator, SQLiteCommon, ansisql.ANSIColumnGenerator): """SQLite ColumnGenerator""" diff --git a/migrate/changeset/databases/visitor.py b/migrate/changeset/databases/visitor.py index 18f1ac0..6db2d51 100644 --- a/migrate/changeset/databases/visitor.py +++ b/migrate/changeset/databases/visitor.py @@ -13,12 +13,12 @@ from migrate.changeset.databases import (sqlite, # Map SA dialects to the corresponding Migrate extensions DIALECTS = { - sa.engine.default.DefaultDialect: ansisql.ANSIDialect, - sa.databases.sqlite.SQLiteDialect: sqlite.SQLiteDialect, - sa.databases.postgres.PGDialect: postgres.PGDialect, - sa.databases.mysql.MySQLDialect: mysql.MySQLDialect, - sa.databases.oracle.OracleDialect: oracle.OracleDialect, - sa.databases.firebird.FBDialect: firebird.FBDialect, + "default": ansisql.ANSIDialect, + "sqlite": sqlite.SQLiteDialect, + "postgres": postgres.PGDialect, + "mysql": mysql.MySQLDialect, + "oracle": oracle.OracleDialect, + "firebird": firebird.FBDialect, } @@ -47,8 +47,8 @@ def get_dialect_visitor(sa_dialect, name): """ # map sa dialect to migrate dialect and return visitor - sa_dialect_cls = sa_dialect.__class__ - migrate_dialect_cls = DIALECTS[sa_dialect_cls] + sa_dialect_name = getattr(sa_dialect, 'name', 'default') + migrate_dialect_cls = DIALECTS[sa_dialect_name] visitor = getattr(migrate_dialect_cls, name) # bind preparer @@ -61,6 +61,10 @@ def run_single_visitor(engine, visitorcallable, element, **kwargs): conn = engine.contextual_connect(close_with_result=False) try: visitor = visitorcallable(engine.dialect, conn) - getattr(visitor, 'visit_' + element.__visit_name__)(element, **kwargs) + if hasattr(element, '__migrate_visit_name__'): + fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__) + else: + fn = getattr(visitor, 'visit_' + element.__visit_name__) + fn(element, **kwargs) finally: conn.close() diff --git a/migrate/changeset/schema.py b/migrate/changeset/schema.py index 3ae9b46..ab839b1 100644 --- a/migrate/changeset/schema.py +++ b/migrate/changeset/schema.py @@ -4,6 +4,7 @@ from UserDict import DictMixin import sqlalchemy +from migrate.changeset import SQLA_06 from migrate.changeset.exceptions import * from migrate.changeset.databases.visitor import (get_engine_visitor, run_single_visitor) @@ -310,7 +311,7 @@ class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem): def process_column(self, column): """Processes default values for column""" # XXX: this is a snippet from SA processing of positional parameters - if column.args: + if not SQLA_06 and column.args: toinit = list(column.args) else: toinit = list() @@ -328,7 +329,9 @@ class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem): for_update=True)) if toinit: column._init_items(*toinit) - column.args = [] + + if not SQLA_06: + column.args = [] def _get_table(self): return getattr(self, '_table', None) @@ -365,9 +368,6 @@ class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem): self.current_name = column.name if self.alter_metadata: self._result_column = column - # remove column from table, nothing has changed yet - if self.table: - column.remove_from_table(self.table) else: self._result_column = column.copy_fixed() diff --git a/migrate/versioning/genmodel.py b/migrate/versioning/genmodel.py index 91f3976..ba455b0 100644 --- a/migrate/versioning/genmodel.py +++ b/migrate/versioning/genmodel.py @@ -34,10 +34,7 @@ class ModelGenerator(object): def __init__(self, diff, declarative=False): self.diff = diff self.declarative = declarative - # is there an easier way to get this? - dialectModule = sys.modules[self.diff.conn.dialect.__module__] - self.colTypeMappings = dict((v, k) for k, v in \ - dialectModule.colspecs.items()) + def column_repr(self, col): kwarg = [] @@ -63,18 +60,18 @@ class ModelGenerator(object): # crs: not sure if this is good idea, but it gets rid of extra # u'' name = col.name.encode('utf8') - type = self.colTypeMappings.get(col.type.__class__, None) - if type: - # Make the column type be an instance of this type. - type = type() - else: - # We must already be a model type, no need to map from the - # database-specific types. - type = col.type + + type_ = col.type + for cls in col.type.__class__.__mro__: + if cls.__module__ == 'sqlalchemy.types' and \ + not cls.__name__.isupper(): + if cls is not type_.__class__: + type_ = cls() + break data = { 'name': name, - 'type': type, + 'type': type_, 'constraints': ', '.join([repr(cn) for cn in col.constraints]), 'args': ks and ks or ''} diff --git a/migrate/versioning/schemadiff.py b/migrate/versioning/schemadiff.py index 6f300b3..8a06643 100644 --- a/migrate/versioning/schemadiff.py +++ b/migrate/versioning/schemadiff.py @@ -2,7 +2,7 @@ Schema differencing support. """ import sqlalchemy - +from migrate.changeset import SQLA_06 def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None): """ @@ -55,9 +55,25 @@ class SchemaDiff(object): """ # Setup common variables. cc = self.conn.contextual_connect() - schemagenerator = self.conn.dialect.schemagenerator( - self.conn.dialect, cc) - + if SQLA_06: + from sqlalchemy.ext import compiler + from sqlalchemy.schema import DDLElement + class DefineColumn(DDLElement): + def __init__(self, col): + self.col = col + + @compiler.compiles(DefineColumn) + def compile(elem, compiler, **kw): + return compiler.get_column_specification(elem.col) + + def get_column_specification(col): + return str(DefineColumn(col).compile(dialect=self.conn.dialect)) + else: + schemagenerator = self.conn.dialect.schemagenerator( + self.conn.dialect, cc) + def get_column_specification(col): + return schemagenerator.get_column_specification(col) + # For each in model, find missing in database. for modelName, modelTable in self.model.tables.items(): if modelName in self.excludeTables: @@ -89,15 +105,16 @@ class SchemaDiff(object): # Find missing columns in model. for databaseCol in reflectedTable.columns: + + # TODO: no test coverage here? (mrb) + modelCol = modelTable.columns.get(databaseCol.name, None) if modelCol: # Compare attributes of column. modelDecl = \ - schemagenerator.get_column_specification( - modelCol) + get_column_specification(modelCol) databaseDecl = \ - schemagenerator.get_column_specification( - databaseCol) + get_column_specification(databaseCol) if modelDecl != databaseDecl: # Unfortunately, sometimes the database # decl won't quite match the model, even diff --git a/migrate/versioning/script/py.py b/migrate/versioning/script/py.py index 6a089ee..15f68b0 100644 --- a/migrate/versioning/script/py.py +++ b/migrate/versioning/script/py.py @@ -112,7 +112,7 @@ class PythonScript(base.BaseScript): """ buf = StringIO() args['engine_arg_strategy'] = 'mock' - args['engine_arg_executor'] = lambda s, p = '': buf.write(s + p) + args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p) engine = construct_engine(url, **args) self.run(engine, step) diff --git a/test/versioning/test_shell.py b/test/versioning/test_shell.py index 8ffd9f0..3544502 100644 --- a/test/versioning/test_shell.py +++ b/test/versioning/test_shell.py @@ -18,7 +18,7 @@ from test import fixture class Shell(fixture.Shell): - _cmd = os.path.join('python migrate', 'versioning', 'shell.py') + _cmd = os.path.join(sys.executable + ' migrate', 'versioning', 'shell.py') @classmethod def cmd(cls, *args): @@ -509,21 +509,21 @@ class TestShellDatabase(Shell, fixture.DB): open(model_path, 'w').write(script_preamble + script_text) # Model is defined but database is empty. - output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path) + output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db' % (sys.executable, script_path)) assert "tables missing in database: tmp_account_rundiffs" in output, output # Test Deprecation - output, exitcode = self.output_and_exitcode('python %s compare_model_to_db --model=testmodel.meta' % script_path) + output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db --model=testmodel.meta' % (sys.executable, script_path)) assert "tables missing in database: tmp_account_rundiffs" in output, output # Update db to latest model. - output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path) + output, exitcode = self.output_and_exitcode('%s %s update_db_from_model' % (sys.executable, script_path)) self.assertEquals(exitcode, None) self.assertEquals(self.cmd_version(repos_path),0) self.assertEquals(self.cmd_db_version(self.url,repos_path),0) # version did not get bumped yet because new version not yet created - output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path) + output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db' % (sys.executable, script_path)) assert "No schema diffs" in output, output - output, exitcode = self.output_and_exitcode('python %s create_model' % script_path) + output, exitcode = self.output_and_exitcode('%s %s create_model' % (sys.executable, script_path)) output = output.replace(genmodel.HEADER.strip(), '') # need strip b/c output_and_exitcode called strip assert """tmp_account_rundiffs = Table('tmp_account_rundiffs', meta, Column('id', Integer(), primary_key=True, nullable=False), @@ -531,9 +531,9 @@ class TestShellDatabase(Shell, fixture.DB): Column('passwd', String(length=None, convert_unicode=False, assert_unicode=None)),""" in output.strip(), output # We're happy with db changes, make first db upgrade script to go from version 0 -> 1. - output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model' % script_path) # intentionally omit a parameter + output, exitcode = self.output_and_exitcode('%s %s make_update_script_for_model' % (sys.executable, script_path)) # intentionally omit a parameter self.assertEquals('Not enough arguments' in output, True) - output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model --oldmodel=oldtestmodel:meta' % script_path) + output, exitcode = self.output_and_exitcode('%s %s make_update_script_for_model --oldmodel=oldtestmodel:meta' % (sys.executable, script_path)) self.assertEqualsIgnoreWhitespace(output, """from sqlalchemy import * from migrate import * @@ -560,9 +560,9 @@ def downgrade(migrate_engine): self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc')) upgrade_script_path = '%s/versions/001_Desc.py' % repos_path open(upgrade_script_path, 'w').write(output) - #output, exitcode = self.output_and_exitcode('python %s test %s' % (script_path, upgrade_script_path)) # no, we already upgraded the db above + #output, exitcode = self.output_and_exitcode('%s %s test %s' % (sys.executable, script_path, upgrade_script_path)) # no, we already upgraded the db above #self.assertEquals(output, "") - output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path) # bump the db_version + output, exitcode = self.output_and_exitcode('%s %s update_db_from_model' % (sys.executable, script_path)) # bump the db_version self.assertEquals(exitcode, None) self.assertEquals(self.cmd_version(repos_path),1) self.assertEquals(self.cmd_db_version(self.url,repos_path),1) |