diff options
author | christian.simms <unknown> | 2008-04-06 20:58:37 +0000 |
---|---|---|
committer | christian.simms <unknown> | 2008-04-06 20:58:37 +0000 |
commit | 8af121c421f45a108894f56c8f713ae455410020 (patch) | |
tree | 0fef911c6c38f1fb8f149bb1848e0b5ecadcbf85 | |
parent | c13931b6b911dde2b3732c23428da7be64786a9e (diff) | |
download | sqalchemy-migrate-8af121c421f45a108894f56c8f713ae455410020.tar.gz |
Change make_update_script_for_model shell command to compare two versions of Python model (issue #12); add shell test for new diff'ing apis
-rw-r--r-- | migrate/versioning/api.py | 10 | ||||
-rw-r--r-- | migrate/versioning/schemadiff.py | 13 | ||||
-rw-r--r-- | migrate/versioning/script/py.py | 9 | ||||
-rw-r--r-- | test/fixture/base.py | 12 | ||||
-rw-r--r-- | test/fixture/shell.py | 2 | ||||
-rw-r--r-- | test/versioning/test_schemadiff.py | 17 | ||||
-rw-r--r-- | test/versioning/test_shell.py | 98 |
7 files changed, 134 insertions, 27 deletions
diff --git a/migrate/versioning/api.py b/migrate/versioning/api.py index 32ed33f..e65d7d5 100644 --- a/migrate/versioning/api.py +++ b/migrate/versioning/api.py @@ -301,18 +301,18 @@ def create_model(url,repository,**opts): engine=create_engine(url) print cls_schema.create_model(engine,repository) -def make_update_script_for_model(path,url,model,repository,**opts): - """%prog make_update_script_for_model PATH URL MODEL REPOSITORY_PATH +def make_update_script_for_model(url,oldmodel,model,repository,**opts): + """%prog make_update_script_for_model URL OLDMODEL MODEL REPOSITORY_PATH - Create a script changing the current (old) database to the current (new) Python model. + Create a script changing the old Python model to the new (current) Python model, sending to stdout. NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label engine=create_engine(url) try: - cls_script_python.make_update_script_for_model(path,engine,model,repository,**opts) + print cls_script_python.make_update_script_for_model(engine,oldmodel,model,repository,**opts) except exceptions.PathFoundError,e: - raise exceptions.KnownError("The path %s already exists"%e.args[0]) + raise exceptions.KnownError("The path %s already exists"%e.args[0]) # TODO: get rid of this? if we don't add back path param def update_db_from_model(url,model,repository,**opts): """%prog update_db_from_model URL MODEL REPOSITORY_PATH diff --git a/migrate/versioning/schemadiff.py b/migrate/versioning/schemadiff.py index 1005986..78ff4dc 100644 --- a/migrate/versioning/schemadiff.py +++ b/migrate/versioning/schemadiff.py @@ -8,17 +8,26 @@ def getDiffOfModelAgainstDatabase(model, conn, excludeTables=None): ''' return SchemaDiff(model, conn, excludeTables) +def getDiffOfModelAgainstModel(oldmodel, model, conn, excludeTables=None): + ''' Return differences of model against database. + Returned object will evaluate to True if there are differences else False. + ''' + return SchemaDiff(model, conn, excludeTables, oldmodel=oldmodel) + class SchemaDiff(object): ''' Differences of model against database. ''' - def __init__(self, model, conn, excludeTables=None): + def __init__(self, model, conn, excludeTables=None, oldmodel=None): ''' Parameter model is your Python model's metadata and conn is an active database connection. ''' self.model = model self.conn = conn if not excludeTables: excludeTables = [] # [] can't be default value in Python parameter self.excludeTables = excludeTables - self.reflected_model = sqlalchemy.MetaData(conn, reflect=True) + if oldmodel: + self.reflected_model = oldmodel + else: + self.reflected_model = sqlalchemy.MetaData(conn, reflect=True) self.tablesMissingInDatabase, self.tablesMissingInModel, self.tablesWithDiff = [], [], [] self.colDiffs = {} self.compareModelToDatabase() diff --git a/migrate/versioning/script/py.py b/migrate/versioning/script/py.py index 7add4f4..1f2685a 100644 --- a/migrate/versioning/script/py.py +++ b/migrate/versioning/script/py.py @@ -20,16 +20,17 @@ class PythonScript(base.BaseScript): shutil.copy(src,path) @classmethod - def make_update_script_for_model(cls,path,engine,model,repository,**opts): + def make_update_script_for_model(cls,engine,oldmodel,model,repository,**opts): """Create a migration script""" - cls.require_notfound(path) + #cls.require_notfound(path) # TODO: yank? # Compute differences. if isinstance(repository, basestring): from migrate.versioning.repository import Repository # oh dear, an import cycle! repository=Repository(repository) + oldmodel = loadModel(oldmodel) model = loadModel(model) - diff = schemadiff.getDiffOfModelAgainstDatabase(model, engine, excludeTables=[repository.version_table]) + diff = schemadiff.getDiffOfModelAgainstModel(oldmodel, model, engine, excludeTables=[repository.version_table]) upgradeDecls, upgradeCommands = genmodel.ModelGenerator(diff).toUpgradePython() #downgradeCommands = genmodel.ModelGenerator(diff).toDowngradePython() @@ -43,7 +44,7 @@ class PythonScript(base.BaseScript): contents = contents.replace(search, upgradeDecls + '\n\n' + search, 1) if upgradeCommands: contents = contents.replace(' pass', upgradeCommands, 1) #if downgradeCommands: contents = contents.replace(' pass', downgradeCommands, 1) # TODO - open(path, 'w').write(contents) + return contents # TODO: reinstate? open(path, 'w').write(contents) @classmethod def verify_module(cls,path): diff --git a/test/fixture/base.py b/test/fixture/base.py index 5de4065..8f97103 100644 --- a/test/fixture/base.py +++ b/test/fixture/base.py @@ -23,6 +23,18 @@ class FakeTestCase(object): assert x != y def assertRaises(self,error,func,*p,**k): assert raises(error,func,*p,**k) + + def assertEqualsIgnoreWhitespace(self, v1, v2): + def createLines(s): + s = s.replace(' ', '') + lines = s.split('\n') + return [ line for line in lines if line ] + lines1 = createLines(v1) + lines2 = createLines(v2) + self.assertEquals(len(lines1), len(lines2)) + for line1, line2 in zip(lines1, lines2): + self.assertEquals(line1, line2) + class Base(FakeTestCase): """Base class for other test cases""" diff --git a/test/fixture/shell.py b/test/fixture/shell.py index 9a262eb..a07012a 100644 --- a/test/fixture/shell.py +++ b/test/fixture/shell.py @@ -16,7 +16,7 @@ class Shell(Pathed): return fd def output_and_exitcode(self,*p,**k): fd=self.execute(*p,**k) - output = fd.read() + output = fd.read().strip() exitcode = fd.close() if k.pop('emit',False): print output diff --git a/test/versioning/test_schemadiff.py b/test/versioning/test_schemadiff.py index a381d67..a8571e8 100644 --- a/test/versioning/test_schemadiff.py +++ b/test/versioning/test_schemadiff.py @@ -9,19 +9,6 @@ class TestSchemaDiff(fixture.DB): level=fixture.DB.CONNECT table_name = 'tmp_schemadiff' - def assertEqualsIgnoreWhitespace(self, v1, v2): - - def createLines(s): - s = s.replace(' ', '') - lines = s.split('\n') - return [ line for line in lines if line ] - - lines1 = createLines(v1) - lines2 = createLines(v2) - self.assertEquals(len(lines1), len(lines2)) - for line1, line2 in zip(lines1, lines2): - self.assertEquals(line1, line2) - def setUp(self): fixture.DB.setUp(self) self._connect(self.url) @@ -39,7 +26,9 @@ class TestSchemaDiff(fixture.DB): def tearDown(self): if self.table.exists(): - self.table.drop() + #self.table.drop() # bummer, this doesn't work because the list of tables is out of date, but calling reflect didn't work + self.meta = MetaData(self.engine, reflect=True) + self.meta.drop_all() fixture.DB.tearDown(self) def _applyLatestModel(self): diff --git a/test/versioning/test_shell.py b/test/versioning/test_shell.py index 68335cc..057741b 100644 --- a/test/versioning/test_shell.py +++ b/test/versioning/test_shell.py @@ -4,7 +4,7 @@ from StringIO import StringIO import os,shutil from test import fixture from migrate.versioning.repository import Repository -from migrate.versioning import shell +from migrate.versioning import genmodel, shell from sqlalchemy import MetaData,Table python_version = sys.version[0:3] @@ -454,3 +454,99 @@ class TestShellDatabase(Shell,fixture.DB): self.assertSuccess(self.cmd('test',script_path,repos_path,self.url)) self.assertEquals(self.cmd_version(repos_path),0) self.assertEquals(self.cmd_db_version(self.url,repos_path),0) + + @fixture.usedb() + def test_rundiffs_in_shell(self): + # This is a variant of the test_schemadiff tests but run through the shell level. + # These shell tests are hard to debug (since they keep forking processes), so they shouldn't replace the lower-level tests. + repos_name = 'repos_name' + repos_path = self.tmp() + script_path = self.tmp_py() + old_model_path = self.tmp_named('oldtestmodel.py') + model_path = self.tmp_named('testmodel.py') + + # Create empty repository. + self.assertSuccess(self.cmd('create',repos_path,repos_name)) + self.assertSuccess(self.cmd('version_control',self.url,repos_path)) + self.assertEquals(self.cmd_version(repos_path),0) + self.assertEquals(self.cmd_db_version(self.url,repos_path),0) + + # Setup helper script. + model_module = 'testmodel.meta' + self.assertSuccess(self.cmd('manage',script_path,'--repository=%s --url=%s --model=%s' % (repos_path, self.url, model_module))) + self.assert_(os.path.exists(script_path)) + + # Write old and new model to disk - old model is empty! + script_preamble=""" + from sqlalchemy import * + + meta = MetaData() + """.replace("\n ","\n") + + script_text=""" + """.replace("\n ","\n") + open(old_model_path, 'w').write(script_preamble + script_text) + + script_text=""" + account = Table('account',meta, + Column('id',Integer,primary_key=True), + Column('login',String(40)), + Column('passwd',String(40)), + ) + """.replace("\n ","\n") + open(model_path, 'w').write(script_preamble + script_text) + + # Model is defined but database is empty. + output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path) + self.assertEquals(output, "Schema diffs:\n tables missing in database: account") + + # Update db to latest model. + output, exitcode = self.output_and_exitcode('python %s update_db_from_model' % script_path) + self.assertEquals(output, "") + output, exitcode = self.output_and_exitcode('python %s compare_model_to_db' % script_path) + self.assertEquals(output, "No schema diffs") + output, exitcode = self.output_and_exitcode('python %s create_model' % script_path) + output = output.replace(genmodel.HEADER.strip(), '') # need strip b/c output_and_exitcode called strip + self.assertEqualsIgnoreWhitespace(output, """ + account = Table('account',meta, + Column('id',Integer(),primary_key=True,nullable=False), + Column('login',String(length=None,convert_unicode=False,assert_unicode=None)), + Column('passwd',String(length=None,convert_unicode=False,assert_unicode=None)), + ) + """) # TODO: length shouldn't be None above + + # We're happy with db changes, make first db upgrade script to go from version 0 -> 1. + output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model' % script_path) # intentionally omit a parameter + self.assertEquals('Error: Too few arguments' in output, True) + output, exitcode = self.output_and_exitcode('python %s make_update_script_for_model --oldmodel=oldtestmodel.meta' % script_path) + self.assertEqualsIgnoreWhitespace(output, """ + from sqlalchemy import * + from migrate import * + + meta = MetaData(migrate_engine) + account = Table('account', meta, + Column('id', Integer() , primary_key=True, nullable=False), + Column('login', String(length=40, convert_unicode=False, assert_unicode=None) ), + Column('passwd', String(length=40, convert_unicode=False, assert_unicode=None) ), + ) + + def upgrade(): + # Upgrade operations go here. Don't create your own engine; use the engine + # named 'migrate_engine' imported from migrate. + account.create() + + def downgrade(): + # Operations to reverse the above upgrade go here. + pass + """) + + # Commit the change. + upgrade_script_path = self.tmp_named('upgrade_script.py') + open(upgrade_script_path, 'w').write(output) + #output, exitcode = self.output_and_exitcode('python %s test %s' % (script_path, upgrade_script_path)) # no, we already upgraded the db above + #self.assertEquals(output, "") + output, exitcode = self.output_and_exitcode('python %s commit %s' % (script_path, upgrade_script_path)) + self.assertEquals(output, "") + self.assertEquals(self.cmd_version(repos_path),1) + #self.assertEquals(self.cmd_db_version(self.url,repos_path),1) TODO finish + |