summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authoriElectric <unknown>2009-06-29 10:18:03 +0000
committeriElectric <unknown>2009-06-29 10:18:03 +0000
commitde3c53989d0f113055e79ce7c3db1dd72f4067ce (patch)
tree922a6d32990d8e46e175930cd861ff77f532fe2a
parentb546dded1f488a3c720375cae991130817a1ea03 (diff)
downloadsqlalchemy-migrate-de3c53989d0f113055e79ce7c3db1dd72f4067ce.tar.gz
add support for SA 0.6 by Michael Bayer
-rw-r--r--docs/changelog.rst1
-rw-r--r--migrate/changeset/__init__.py12
-rw-r--r--migrate/changeset/ansisql.py165
-rw-r--r--migrate/changeset/constraint.py18
-rw-r--r--migrate/changeset/databases/firebird.py5
-rw-r--r--migrate/changeset/databases/mysql.py66
-rw-r--r--migrate/changeset/databases/oracle.py9
-rw-r--r--migrate/changeset/databases/postgres.py9
-rw-r--r--migrate/changeset/databases/sqlite.py11
-rw-r--r--migrate/changeset/databases/visitor.py22
-rw-r--r--migrate/changeset/schema.py10
-rw-r--r--migrate/versioning/genmodel.py23
-rw-r--r--migrate/versioning/schemadiff.py33
-rw-r--r--migrate/versioning/script/py.py2
-rw-r--r--test/versioning/test_shell.py20
15 files changed, 252 insertions, 154 deletions
diff --git a/docs/changelog.rst b/docs/changelog.rst
index 58e6ec3..89a8238 100644
--- a/docs/changelog.rst
+++ b/docs/changelog.rst
@@ -1,6 +1,7 @@
0.5.5
-----
+- added support for SQLAlchemy 0.6 (missing oracle and firebird) by Michael Bayer
- alter, create, drop column / rename table / rename index constructs now accept `alter_metadata` parameter. If True, it will modify Column/Table objects according to changes. Otherwise, everything will be untouched.
- complete refactoring of :class:`~migrate.changeset.schema.ColumnDelta` (fixes issue 23)
- added support for :ref:`firebird <firebird-d>`
diff --git a/migrate/changeset/__init__.py b/migrate/changeset/__init__.py
index 940c23f..25fc73b 100644
--- a/migrate/changeset/__init__.py
+++ b/migrate/changeset/__init__.py
@@ -6,9 +6,21 @@
"""
import sqlalchemy
+from sqlalchemy import __version__ as _sa_version
+import re
+
+_sa_version = tuple(int(re.match("\d+", x).group(0)) for x in _sa_version.split("."))
+SQLA_06 = _sa_version >= (0, 6)
+
+del re
+del _sa_version
+
from migrate.changeset.schema import *
from migrate.changeset.constraint import *
+
+
+
sqlalchemy.schema.Table.__bases__ += (ChangesetTable, )
sqlalchemy.schema.Column.__bases__ += (ChangesetColumn, )
sqlalchemy.schema.Index.__bases__ += (ChangesetIndex, )
diff --git a/migrate/changeset/ansisql.py b/migrate/changeset/ansisql.py
index f0dfed5..ee585a7 100644
--- a/migrate/changeset/ansisql.py
+++ b/migrate/changeset/ansisql.py
@@ -5,23 +5,53 @@
things that just happen to work with multiple databases.
"""
import sqlalchemy as sa
+from sqlalchemy.schema import SchemaVisitor
from sqlalchemy.engine.default import DefaultDialect
from sqlalchemy.schema import (ForeignKeyConstraint,
PrimaryKeyConstraint,
CheckConstraint,
UniqueConstraint,
Index)
-from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
-from migrate.changeset import exceptions, constraint
+from migrate.changeset import exceptions, constraint, SQLA_06
+import StringIO
+if not SQLA_06:
+ from sqlalchemy.sql.compiler import SchemaGenerator, SchemaDropper
+else:
+ from sqlalchemy.schema import AddConstraint, DropConstraint
+ from sqlalchemy.sql.compiler import DDLCompiler
+ SchemaGenerator = SchemaDropper = DDLCompiler
-SchemaIterator = sa.engine.SchemaIterator
+class AlterTableVisitor(SchemaVisitor):
+ """Common operations for ``ALTER TABLE`` statements."""
+ def append(self, s):
+ """Append content to the SchemaIterator's query buffer."""
-class AlterTableVisitor(SchemaIterator):
- """Common operations for ``ALTER TABLE`` statements."""
+ self.buffer.write(s)
+
+ def execute(self):
+ """Execute the contents of the SchemaIterator's buffer."""
+
+ try:
+ return self.connection.execute(self.buffer.getvalue())
+ finally:
+ self.buffer.truncate(0)
+
+ def __init__(self, dialect, connection, **kw):
+ self.connection = connection
+ self.buffer = StringIO.StringIO()
+ self.preparer = dialect.identifier_preparer
+ self.dialect = dialect
+ def traverse_single(self, elem):
+ ret = super(AlterTableVisitor, self).traverse_single(elem)
+ if ret:
+ # adapt to 0.6 which uses a string-returning
+ # object
+ self.append(ret)
+
def _to_table(self, param):
"""Returns the table object for the given param object."""
if isinstance(param, (sa.Column, sa.Index, sa.schema.Constraint)):
@@ -88,6 +118,9 @@ class ANSIColumnGenerator(AlterTableVisitor, SchemaGenerator):
name=column.primary_key_name)
cons.create()
+ if SQLA_06:
+ def add_foreignkey(self, fk):
+ self.connection.execute(AddConstraint(fk))
class ANSIColumnDropper(AlterTableVisitor, SchemaDropper):
"""Extends ANSI SQL dropper for column dropping (``ALTER TABLE
@@ -181,7 +214,10 @@ class ANSISchemaChanger(AlterTableVisitor, SchemaGenerator):
def _visit_column_type(self, table, column, delta):
type_ = delta['type']
- type_text = type_.dialect_impl(self.dialect).get_col_spec()
+ if SQLA_06:
+ type_text = str(type_.compile(dialect=self.dialect))
+ else:
+ type_text = type_.dialect_impl(self.dialect).get_col_spec()
self.append("TYPE %s" % type_text)
def _visit_column_name(self, table, column, delta):
@@ -225,60 +261,75 @@ class ANSIConstraintCommon(AlterTableVisitor):
def visit_migrate_unique_constraint(self, *p, **k):
self._visit_constraint(*p, **k)
+if SQLA_06:
+ class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
+ def _visit_constraint(self, constraint):
+ constraint.name = self.get_constraint_name(constraint)
+ self.append(self.process(AddConstraint(constraint)))
+ self.execute()
-class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
+ class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
+ def _visit_constraint(self, constraint):
+ constraint.name = self.get_constraint_name(constraint)
+ self.append(self.process(DropConstraint(constraint, cascade=constraint.cascade)))
+ self.execute()
- def get_constraint_specification(self, cons, **kwargs):
- """Constaint SQL generators.
-
- We cannot use SA visitors because they append comma.
- """
- if isinstance(cons, PrimaryKeyConstraint):
- if cons.name is not None:
- self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
- self.append("PRIMARY KEY ")
- self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
- for c in cons))
- self.define_constraint_deferrability(cons)
- elif isinstance(cons, ForeignKeyConstraint):
- self.define_foreign_key(cons)
- elif isinstance(cons, CheckConstraint):
- if cons.name is not None:
- self.append("CONSTRAINT %s " %
- self.preparer.format_constraint(cons))
- self.append("CHECK (%s)" % cons.sqltext)
- self.define_constraint_deferrability(cons)
- elif isinstance(cons, UniqueConstraint):
- if cons.name is not None:
- self.append("CONSTRAINT %s " %
- self.preparer.format_constraint(cons))
- self.append("UNIQUE (%s)" % \
- (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
- self.define_constraint_deferrability(cons)
- else:
- raise exceptions.InvalidConstraintError(cons)
+else:
+ class ANSIConstraintGenerator(ANSIConstraintCommon, SchemaGenerator):
- def _visit_constraint(self, constraint):
- table = self.start_alter_table(constraint)
- constraint.name = self.get_constraint_name(constraint)
- self.append("ADD ")
- self.get_constraint_specification(constraint)
- self.execute()
-
-
-class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
-
- def _visit_constraint(self, constraint):
- self.start_alter_table(constraint)
- self.append("DROP CONSTRAINT ")
- constraint.name = self.get_constraint_name(constraint)
- self.append(self.preparer.format_constraint(constraint))
- if constraint.cascade:
- self.cascade_constraint(constraint)
- self.execute()
-
- def cascade_constraint(self, constraint):
- self.append(" CASCADE")
+ def get_constraint_specification(self, cons, **kwargs):
+ """Constaint SQL generators.
+
+ We cannot use SA visitors because they append comma.
+ """
+
+ if isinstance(cons, PrimaryKeyConstraint):
+ if cons.name is not None:
+ self.append("CONSTRAINT %s " % self.preparer.format_constraint(cons))
+ self.append("PRIMARY KEY ")
+ self.append("(%s)" % ', '.join(self.preparer.quote(c.name, c.quote)
+ for c in cons))
+ self.define_constraint_deferrability(cons)
+ elif isinstance(cons, ForeignKeyConstraint):
+ self.define_foreign_key(cons)
+ elif isinstance(cons, CheckConstraint):
+ if cons.name is not None:
+ self.append("CONSTRAINT %s " %
+ self.preparer.format_constraint(cons))
+ self.append("CHECK (%s)" % cons.sqltext)
+ self.define_constraint_deferrability(cons)
+ elif isinstance(cons, UniqueConstraint):
+ if cons.name is not None:
+ self.append("CONSTRAINT %s " %
+ self.preparer.format_constraint(cons))
+ self.append("UNIQUE (%s)" % \
+ (', '.join(self.preparer.quote(c.name, c.quote) for c in cons)))
+ self.define_constraint_deferrability(cons)
+ else:
+ raise exceptions.InvalidConstraintError(cons)
+
+ def _visit_constraint(self, constraint):
+
+ table = self.start_alter_table(constraint)
+ constraint.name = self.get_constraint_name(constraint)
+ self.append("ADD ")
+ self.get_constraint_specification(constraint)
+ self.execute()
+
+
+ class ANSIConstraintDropper(ANSIConstraintCommon, SchemaDropper):
+
+ def _visit_constraint(self, constraint):
+ self.start_alter_table(constraint)
+ self.append("DROP CONSTRAINT ")
+ constraint.name = self.get_constraint_name(constraint)
+ self.append(self.preparer.format_constraint(constraint))
+ if constraint.cascade:
+ self.cascade_constraint(constraint)
+ self.execute()
+
+ def cascade_constraint(self, constraint):
+ self.append(" CASCADE")
class ANSIDialect(DefaultDialect):
diff --git a/migrate/changeset/constraint.py b/migrate/changeset/constraint.py
index b09b30a..72251f5 100644
--- a/migrate/changeset/constraint.py
+++ b/migrate/changeset/constraint.py
@@ -4,7 +4,7 @@
from sqlalchemy import schema
from migrate.changeset.exceptions import *
-
+from migrate.changeset import SQLA_06
class ConstraintChangeset(object):
"""Base class for Constraint classes."""
@@ -54,7 +54,10 @@ class ConstraintChangeset(object):
"""
self.cascade = kw.pop('cascade', False)
self.__do_imports('constraintdropper', *a, **kw)
- self.columns.clear()
+ # the spirit of Constraint objects is that they
+ # are immutable (just like in a DB. they're only ADDed
+ # or DROPped).
+ #self.columns.clear()
return self
@@ -69,7 +72,7 @@ class PrimaryKeyConstraint(ConstraintChangeset, schema.PrimaryKeyConstraint):
:type cols: strings or Column instances
"""
- __visit_name__ = 'migrate_primary_key_constraint'
+ __migrate_visit_name__ = 'migrate_primary_key_constraint'
def __init__(self, *cols, **kwargs):
colnames, table = self._normalize_columns(cols)
@@ -97,7 +100,7 @@ class ForeignKeyConstraint(ConstraintChangeset, schema.ForeignKeyConstraint):
:type refcolumns: list of strings or Column instances
"""
- __visit_name__ = 'migrate_foreign_key_constraint'
+ __migrate_visit_name__ = 'migrate_foreign_key_constraint'
def __init__(self, columns, refcolumns, *args, **kwargs):
colnames, table = self._normalize_columns(columns)
@@ -139,7 +142,7 @@ class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
:type sqltext: string
"""
- __visit_name__ = 'migrate_check_constraint'
+ __migrate_visit_name__ = 'migrate_check_constraint'
def __init__(self, sqltext, *args, **kwargs):
cols = kwargs.pop('columns', [])
@@ -150,7 +153,8 @@ class CheckConstraint(ConstraintChangeset, schema.CheckConstraint):
table = kwargs.pop('table', table)
schema.CheckConstraint.__init__(self, sqltext, *args, **kwargs)
if table is not None:
- self.table = table
+ if not SQLA_06:
+ self.table = table
self._set_parent(table)
self.colnames = colnames
@@ -172,7 +176,7 @@ class UniqueConstraint(ConstraintChangeset, schema.UniqueConstraint):
.. versionadded:: 0.5.5
"""
- __visit_name__ = 'migrate_unique_constraint'
+ __migrate_visit_name__ = 'migrate_unique_constraint'
def __init__(self, *cols, **kwargs):
self.colnames, table = self._normalize_columns(cols)
diff --git a/migrate/changeset/databases/firebird.py b/migrate/changeset/databases/firebird.py
index d60cf00..5eacd58 100644
--- a/migrate/changeset/databases/firebird.py
+++ b/migrate/changeset/databases/firebird.py
@@ -1,14 +1,13 @@
"""
Firebird database specific implementations of changeset classes.
"""
-from sqlalchemy.databases import firebird as sa_base
from migrate.changeset import ansisql, exceptions
-
+# TODO: SQLA 0.6 has not migrated the FB dialect over yet
+from sqlalchemy.databases import firebird as sa_base
FBSchemaGenerator = sa_base.FBSchemaGenerator
-
class FBColumnGenerator(FBSchemaGenerator, ansisql.ANSIColumnGenerator):
"""Firebird column generator implementation."""
diff --git a/migrate/changeset/databases/mysql.py b/migrate/changeset/databases/mysql.py
index 5b5a16e..6655a42 100644
--- a/migrate/changeset/databases/mysql.py
+++ b/migrate/changeset/databases/mysql.py
@@ -2,13 +2,13 @@
MySQL database specific implementations of changeset classes.
"""
+from migrate.changeset import ansisql, exceptions, SQLA_06
from sqlalchemy.databases import mysql as sa_base
-from migrate.changeset import ansisql, exceptions
-
-
-MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator
-
+if not SQLA_06:
+ MySQLSchemaGenerator = sa_base.MySQLSchemaGenerator
+else:
+ MySQLSchemaGenerator = sa_base.MySQLDDLCompiler
class MySQLColumnGenerator(MySQLSchemaGenerator, ansisql.ANSIColumnGenerator):
pass
@@ -39,31 +39,37 @@ class MySQLSchemaChanger(MySQLSchemaGenerator, ansisql.ANSISchemaChanger):
class MySQLConstraintGenerator(ansisql.ANSIConstraintGenerator):
pass
-
-class MySQLConstraintDropper(ansisql.ANSIConstraintDropper):
-
- def visit_migrate_primary_key_constraint(self, constraint):
- self.start_alter_table(constraint)
- self.append("DROP PRIMARY KEY")
- self.execute()
-
- def visit_migrate_foreign_key_constraint(self, constraint):
- self.start_alter_table(constraint)
- self.append("DROP FOREIGN KEY ")
- constraint.name = self.get_constraint_name(constraint)
- self.append(self.preparer.format_constraint(constraint))
- self.execute()
-
- def visit_migrate_check_constraint(self, *p, **k):
- raise exceptions.NotSupportedError("MySQL does not support CHECK"
- " constraints, use triggers instead.")
-
- def visit_migrate_unique_constraint(self, constraint, *p, **k):
- self.start_alter_table(constraint)
- self.append('DROP INDEX ')
- constraint.name = self.get_constraint_name(constraint)
- self.append(self.preparer.format_constraint(constraint))
- self.execute()
+if SQLA_06:
+ class MySQLConstraintDropper(MySQLSchemaGenerator, ansisql.ANSIConstraintDropper):
+ def visit_migrate_check_constraint(self, *p, **k):
+ raise exceptions.NotSupportedError("MySQL does not support CHECK"
+ " constraints, use triggers instead.")
+
+else:
+ class MySQLConstraintDropper(ansisql.ANSIConstraintDropper):
+
+ def visit_migrate_primary_key_constraint(self, constraint):
+ self.start_alter_table(constraint)
+ self.append("DROP PRIMARY KEY")
+ self.execute()
+
+ def visit_migrate_foreign_key_constraint(self, constraint):
+ self.start_alter_table(constraint)
+ self.append("DROP FOREIGN KEY ")
+ constraint.name = self.get_constraint_name(constraint)
+ self.append(self.preparer.format_constraint(constraint))
+ self.execute()
+
+ def visit_migrate_check_constraint(self, *p, **k):
+ raise exceptions.NotSupportedError("MySQL does not support CHECK"
+ " constraints, use triggers instead.")
+
+ def visit_migrate_unique_constraint(self, constraint, *p, **k):
+ self.start_alter_table(constraint)
+ self.append('DROP INDEX ')
+ constraint.name = self.get_constraint_name(constraint)
+ self.append(self.preparer.format_constraint(constraint))
+ self.execute()
class MySQLDialect(ansisql.ANSIDialect):
diff --git a/migrate/changeset/databases/oracle.py b/migrate/changeset/databases/oracle.py
index 93c9f8f..fd2749a 100644
--- a/migrate/changeset/databases/oracle.py
+++ b/migrate/changeset/databases/oracle.py
@@ -2,12 +2,17 @@
Oracle database specific implementations of changeset classes.
"""
import sqlalchemy as sa
-from sqlalchemy.databases import oracle as sa_base
from migrate.changeset import ansisql, exceptions
+from sqlalchemy.databases import oracle as sa_base
+
+from migrate.changeset import ansisql, exceptions, SQLA_06
-OracleSchemaGenerator = sa_base.OracleSchemaGenerator
+if not SQLA_06:
+ OracleSchemaGenerator = sa_base.OracleSchemaGenerator
+else:
+ OracleSchemaGenerator = sa_base.OracleDDLCompiler
class OracleColumnGenerator(OracleSchemaGenerator, ansisql.ANSIColumnGenerator):
diff --git a/migrate/changeset/databases/postgres.py b/migrate/changeset/databases/postgres.py
index bcdc08b..2c36ed1 100644
--- a/migrate/changeset/databases/postgres.py
+++ b/migrate/changeset/databases/postgres.py
@@ -3,12 +3,13 @@
.. _`PostgreSQL`: http://www.postgresql.org/
"""
-from migrate.changeset import ansisql
+from migrate.changeset import ansisql, SQLA_06
from sqlalchemy.databases import postgres as sa_base
-#import sqlalchemy as sa
-
-PGSchemaGenerator = sa_base.PGSchemaGenerator
+if not SQLA_06:
+ PGSchemaGenerator = sa_base.PGSchemaGenerator
+else:
+ PGSchemaGenerator = sa_base.PGDDLCompiler
class PGColumnGenerator(PGSchemaGenerator, ansisql.ANSIColumnGenerator):
diff --git a/migrate/changeset/databases/sqlite.py b/migrate/changeset/databases/sqlite.py
index 59902b4..64be9bf 100644
--- a/migrate/changeset/databases/sqlite.py
+++ b/migrate/changeset/databases/sqlite.py
@@ -8,10 +8,12 @@ from copy import copy
from sqlalchemy.databases import sqlite as sa_base
-from migrate.changeset import ansisql, exceptions
+from migrate.changeset import ansisql, exceptions, SQLA_06
-
-SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
+if not SQLA_06:
+ SQLiteSchemaGenerator = sa_base.SQLiteSchemaGenerator
+else:
+ SQLiteSchemaGenerator = sa_base.SQLiteDDLCompiler
class SQLiteCommon(object):
@@ -52,8 +54,7 @@ class SQLiteHelper(SQLiteCommon):
table.indexes = ixbackup
table.constraints = consbackup
-
-class SQLiteColumnGenerator(SQLiteSchemaGenerator, SQLiteCommon,
+class SQLiteColumnGenerator(SQLiteSchemaGenerator, SQLiteCommon,
ansisql.ANSIColumnGenerator):
"""SQLite ColumnGenerator"""
diff --git a/migrate/changeset/databases/visitor.py b/migrate/changeset/databases/visitor.py
index 18f1ac0..6db2d51 100644
--- a/migrate/changeset/databases/visitor.py
+++ b/migrate/changeset/databases/visitor.py
@@ -13,12 +13,12 @@ from migrate.changeset.databases import (sqlite,
# Map SA dialects to the corresponding Migrate extensions
DIALECTS = {
- sa.engine.default.DefaultDialect: ansisql.ANSIDialect,
- sa.databases.sqlite.SQLiteDialect: sqlite.SQLiteDialect,
- sa.databases.postgres.PGDialect: postgres.PGDialect,
- sa.databases.mysql.MySQLDialect: mysql.MySQLDialect,
- sa.databases.oracle.OracleDialect: oracle.OracleDialect,
- sa.databases.firebird.FBDialect: firebird.FBDialect,
+ "default": ansisql.ANSIDialect,
+ "sqlite": sqlite.SQLiteDialect,
+ "postgres": postgres.PGDialect,
+ "mysql": mysql.MySQLDialect,
+ "oracle": oracle.OracleDialect,
+ "firebird": firebird.FBDialect,
}
@@ -47,8 +47,8 @@ def get_dialect_visitor(sa_dialect, name):
"""
# map sa dialect to migrate dialect and return visitor
- sa_dialect_cls = sa_dialect.__class__
- migrate_dialect_cls = DIALECTS[sa_dialect_cls]
+ sa_dialect_name = getattr(sa_dialect, 'name', 'default')
+ migrate_dialect_cls = DIALECTS[sa_dialect_name]
visitor = getattr(migrate_dialect_cls, name)
# bind preparer
@@ -61,6 +61,10 @@ def run_single_visitor(engine, visitorcallable, element, **kwargs):
conn = engine.contextual_connect(close_with_result=False)
try:
visitor = visitorcallable(engine.dialect, conn)
- getattr(visitor, 'visit_' + element.__visit_name__)(element, **kwargs)
+ if hasattr(element, '__migrate_visit_name__'):
+ fn = getattr(visitor, 'visit_' + element.__migrate_visit_name__)
+ else:
+ fn = getattr(visitor, 'visit_' + element.__visit_name__)
+ fn(element, **kwargs)
finally:
conn.close()
diff --git a/migrate/changeset/schema.py b/migrate/changeset/schema.py
index 3ae9b46..ab839b1 100644
--- a/migrate/changeset/schema.py
+++ b/migrate/changeset/schema.py
@@ -4,6 +4,7 @@
from UserDict import DictMixin
import sqlalchemy
+from migrate.changeset import SQLA_06
from migrate.changeset.exceptions import *
from migrate.changeset.databases.visitor import (get_engine_visitor,
run_single_visitor)
@@ -310,7 +311,7 @@ class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
def process_column(self, column):
"""Processes default values for column"""
# XXX: this is a snippet from SA processing of positional parameters
- if column.args:
+ if not SQLA_06 and column.args:
toinit = list(column.args)
else:
toinit = list()
@@ -328,7 +329,9 @@ class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
for_update=True))
if toinit:
column._init_items(*toinit)
- column.args = []
+
+ if not SQLA_06:
+ column.args = []
def _get_table(self):
return getattr(self, '_table', None)
@@ -365,9 +368,6 @@ class ColumnDelta(DictMixin, sqlalchemy.schema.SchemaItem):
self.current_name = column.name
if self.alter_metadata:
self._result_column = column
- # remove column from table, nothing has changed yet
- if self.table:
- column.remove_from_table(self.table)
else:
self._result_column = column.copy_fixed()
diff --git a/migrate/versioning/genmodel.py b/migrate/versioning/genmodel.py
index 91f3976..ba455b0 100644
--- a/migrate/versioning/genmodel.py
+++ b/migrate/versioning/genmodel.py
@@ -34,10 +34,7 @@ class ModelGenerator(object):
def __init__(self, diff, declarative=False):
self.diff = diff
self.declarative = declarative
- # is there an easier way to get this?
- dialectModule = sys.modules[self.diff.conn.dialect.__module__]
- self.colTypeMappings = dict((v, k) for k, v in \
- dialectModule.colspecs.items())
+
def column_repr(self, col):
kwarg = []
@@ -63,18 +60,18 @@ class ModelGenerator(object):
# crs: not sure if this is good idea, but it gets rid of extra
# u''
name = col.name.encode('utf8')
- type = self.colTypeMappings.get(col.type.__class__, None)
- if type:
- # Make the column type be an instance of this type.
- type = type()
- else:
- # We must already be a model type, no need to map from the
- # database-specific types.
- type = col.type
+
+ type_ = col.type
+ for cls in col.type.__class__.__mro__:
+ if cls.__module__ == 'sqlalchemy.types' and \
+ not cls.__name__.isupper():
+ if cls is not type_.__class__:
+ type_ = cls()
+ break
data = {
'name': name,
- 'type': type,
+ 'type': type_,
'constraints': ', '.join([repr(cn) for cn in col.constraints]),
'args': ks and ks or ''}
diff --git a/migrate/versioning/schemadiff.py b/migrate/versioning/schemadiff.py
index 6f300b3..8a06643 100644
--- a/migrate/versioning/schemadiff.py
+++ b/migrate/versioning/schemadiff.py
@@ -2,7 +2,7 @@
Schema differencing support.
"""
import sqlalchemy
-
+from migrate.changeset import SQLA_06
def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None):
"""
@@ -55,9 +55,25 @@ class SchemaDiff(object):
"""
# Setup common variables.
cc = self.conn.contextual_connect()
- schemagenerator = self.conn.dialect.schemagenerator(
- self.conn.dialect, cc)
-
+ if SQLA_06:
+ from sqlalchemy.ext import compiler
+ from sqlalchemy.schema import DDLElement
+ class DefineColumn(DDLElement):
+ def __init__(self, col):
+ self.col = col
+
+ @compiler.compiles(DefineColumn)
+ def compile(elem, compiler, **kw):
+ return compiler.get_column_specification(elem.col)
+
+ def get_column_specification(col):
+ return str(DefineColumn(col).compile(dialect=self.conn.dialect))
+ else:
+ schemagenerator = self.conn.dialect.schemagenerator(
+ self.conn.dialect, cc)
+ def get_column_specification(col):
+ return schemagenerator.get_column_specification(col)
+
# For each in model, find missing in database.
for modelName, modelTable in self.model.tables.items():
if modelName in self.excludeTables:
@@ -89,15 +105,16 @@ class SchemaDiff(object):
# Find missing columns in model.
for databaseCol in reflectedTable.columns:
+
+ # TODO: no test coverage here? (mrb)
+
modelCol = modelTable.columns.get(databaseCol.name, None)
if modelCol:
# Compare attributes of column.
modelDecl = \
- schemagenerator.get_column_specification(
- modelCol)
+ get_column_specification(modelCol)
databaseDecl = \
- schemagenerator.get_column_specification(
- databaseCol)
+ get_column_specification(databaseCol)
if modelDecl != databaseDecl:
# Unfortunately, sometimes the database
# decl won't quite match the model, even
diff --git a/migrate/versioning/script/py.py b/migrate/versioning/script/py.py
index 6a089ee..15f68b0 100644
--- a/migrate/versioning/script/py.py
+++ b/migrate/versioning/script/py.py
@@ -112,7 +112,7 @@ class PythonScript(base.BaseScript):
"""
buf = StringIO()
args['engine_arg_strategy'] = 'mock'
- args['engine_arg_executor'] = lambda s, p = '': buf.write(s + p)
+ args['engine_arg_executor'] = lambda s, p = '': buf.write(str(s) + p)
engine = construct_engine(url, **args)
self.run(engine, step)
diff --git a/test/versioning/test_shell.py b/test/versioning/test_shell.py
index 8ffd9f0..3544502 100644
--- a/test/versioning/test_shell.py
+++ b/test/versioning/test_shell.py
@@ -18,7 +18,7 @@ from test import fixture
class Shell(fixture.Shell):
- _cmd = os.path.join('python migrate', 'versioning', 'shell.py')
+ _cmd = os.path.join(sys.executable + ' migrate', 'versioning', 'shell.py')
@classmethod
def cmd(cls, *args):
@@ -509,21 +509,21 @@ class TestShellDatabase(Shell, fixture.DB):
open(model_path, 'w').write(script_preamble + script_text)
# Model is defined but database is empty.
- output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path)
+ output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db' % (sys.executable, script_path))
assert "tables missing in database: tmp_account_rundiffs" in output, output
# Test Deprecation
- output, exitcode = self.output_and_exitcode('python %s compare_model_to_db --model=testmodel.meta' % script_path)
+ output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db --model=testmodel.meta' % (sys.executable, script_path))
assert "tables missing in database: tmp_account_rundiffs" in output, output
# Update db to latest model.
- output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path)
+ output, exitcode = self.output_and_exitcode('%s %s update_db_from_model' % (sys.executable, script_path))
self.assertEquals(exitcode, None)
self.assertEquals(self.cmd_version(repos_path),0)
self.assertEquals(self.cmd_db_version(self.url,repos_path),0) # version did not get bumped yet because new version not yet created
- output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path)
+ output, exitcode = self.output_and_exitcode('%s %s compare_model_to_db' % (sys.executable, script_path))
assert "No schema diffs" in output, output
- output, exitcode = self.output_and_exitcode('python %s create_model' % script_path)
+ output, exitcode = self.output_and_exitcode('%s %s create_model' % (sys.executable, script_path))
output = output.replace(genmodel.HEADER.strip(), '') # need strip b/c output_and_exitcode called strip
assert """tmp_account_rundiffs = Table('tmp_account_rundiffs', meta,
Column('id', Integer(), primary_key=True, nullable=False),
@@ -531,9 +531,9 @@ class TestShellDatabase(Shell, fixture.DB):
Column('passwd', String(length=None, convert_unicode=False, assert_unicode=None)),""" in output.strip(), output
# We're happy with db changes, make first db upgrade script to go from version 0 -> 1.
- output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model' % script_path) # intentionally omit a parameter
+ output, exitcode = self.output_and_exitcode('%s %s make_update_script_for_model' % (sys.executable, script_path)) # intentionally omit a parameter
self.assertEquals('Not enough arguments' in output, True)
- output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model --oldmodel=oldtestmodel:meta' % script_path)
+ output, exitcode = self.output_and_exitcode('%s %s make_update_script_for_model --oldmodel=oldtestmodel:meta' % (sys.executable, script_path))
self.assertEqualsIgnoreWhitespace(output,
"""from sqlalchemy import *
from migrate import *
@@ -560,9 +560,9 @@ def downgrade(migrate_engine):
self.assertSuccess(self.cmd('script', '--repository=%s' % repos_path, 'Desc'))
upgrade_script_path = '%s/versions/001_Desc.py' % repos_path
open(upgrade_script_path, 'w').write(output)
- #output, exitcode = self.output_and_exitcode('python %s test %s' % (script_path, upgrade_script_path)) # no, we already upgraded the db above
+ #output, exitcode = self.output_and_exitcode('%s %s test %s' % (sys.executable, script_path, upgrade_script_path)) # no, we already upgraded the db above
#self.assertEquals(output, "")
- output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path) # bump the db_version
+ output, exitcode = self.output_and_exitcode('%s %s update_db_from_model' % (sys.executable, script_path)) # bump the db_version
self.assertEquals(exitcode, None)
self.assertEquals(self.cmd_version(repos_path),1)
self.assertEquals(self.cmd_db_version(self.url,repos_path),1)