diff options
author | iElectric <unknown> | 2009-06-03 00:48:59 +0000 |
---|---|---|
committer | iElectric <unknown> | 2009-06-03 00:48:59 +0000 |
commit | 4015cbf7e9d1fb845f8c4a2d0d3dc3ad22738e8a (patch) | |
tree | 4cb4e02e4357698168b7d1c4be03751b9a55997f /migrate | |
parent | c6883c0d472df8007db1f6c9a9930d1c53c717ae (diff) | |
download | sqalchemy-migrate-4015cbf7e9d1fb845f8c4a2d0d3dc3ad22738e8a.tar.gz |
Issue 38; add ability to pass arguments/dict for create_engine func
Diffstat (limited to 'migrate')
-rw-r--r-- | migrate/versioning/api.py | 189 | ||||
-rw-r--r-- | migrate/versioning/util/__init__.py | 20 |
2 files changed, 129 insertions, 80 deletions
diff --git a/migrate/versioning/api.py b/migrate/versioning/api.py index 80ee842..92e8bdb 100644 --- a/migrate/versioning/api.py +++ b/migrate/versioning/api.py @@ -14,12 +14,13 @@ import sys import inspect +import warnings from sqlalchemy import create_engine from migrate.versioning import (exceptions, repository, schema, version, script as script_) # command name conflict -from migrate.versioning.util import asbool, catch_known_errors +from migrate.versioning.util import asbool, catch_known_errors, guess_obj_type __all__ = [ 'help', @@ -41,10 +42,10 @@ __all__ = [ 'update_db_from_model', ] -cls_repository = repository.Repository -cls_schema = schema.ControlledSchema -cls_vernum = version.VerNum -cls_script_python = script_.PythonScript +Repository = repository.Repository +ControlledSchema = schema.ControlledSchema +VerNum = version.VerNum +PythonScript = script_.PythonScript # deprecated @@ -75,7 +76,7 @@ def create(repository, name, **opts): 'migrate_version'. This table is created in all version-controlled databases. """ - rep = cls_repository.create(repository, name, **opts) + repo_path = Repository.create(repository, name, **opts) @catch_known_errors @@ -88,8 +89,8 @@ def script(description, repository, **opts): For instance, manage.py script "Add initial tables" creates: repository/versions/001_Add_initial_tables.py """ - repos = cls_repository(repository) - repos.create_script(description, **opts) + repo = Repository(repository) + repo.create_script(description, **opts) @catch_known_errors @@ -104,8 +105,8 @@ def script_sql(database, repository, **opts): repository/versions/001_postgres_upgrade.sql and repository/versions/001_postgres_postgres.sql """ - repos = cls_repository(repository) - repos.create_script_sql(database, **opts) + repo = Repository(repository) + repo.create_script_sql(database, **opts) def test(repository, url=None, **opts): @@ -116,8 +117,8 @@ def test(repository, url=None, **opts): bad state. You should therefore better run the test on a copy of your database. """ - engine = create_engine(url) - repos = cls_repository(repository) + engine = _construct_engine(url, **opts) + repos = Repository(repository) script = repos.version(None).script() # Upgrade @@ -136,8 +137,8 @@ def version(repository, **opts): Display the latest version available in a repository. """ - repos = cls_repository(repository) - return repos.latest + repo = Repository(repository) + return repo.latest def source(version, dest=None, repository=None, **opts): @@ -149,8 +150,8 @@ def source(version, dest=None, repository=None, **opts): """ if repository is None: raise exceptions.UsageError("A repository must be specified") - repos = cls_repository(repository) - ret = repos.version(version).script().source() + repo = Repository(repository) + ret = repo.version(version).script().source() if dest is not None: dest = open(dest, 'w') dest.write(ret) @@ -178,9 +179,8 @@ def version_control(url, repository, version=None, **opts): identical to what it would be if the database were created from scratch. """ - echo = asbool(opts.get('echo', False)) - engine = create_engine(url, echo=echo) - cls_schema.create(engine, repository, version) + engine = _construct_engine(url, **opts) + ControlledSchema.create(engine, repository, version) def db_version(url, repository, **opts): @@ -192,9 +192,8 @@ def db_version(url, repository, **opts): The url should be any valid SQLAlchemy connection string. """ - echo = asbool(opts.get('echo', False)) - engine = create_engine(url, echo=echo) - schema = cls_schema(engine, repository) + engine = _construct_engine(url, **opts) + schema = ControlledSchema(engine, repository) return schema.version @@ -230,58 +229,15 @@ def downgrade(url, repository, version, **opts): err = "Cannot downgrade a database of version %s to version %s. "\ "Try 'upgrade' instead." return _migrate(url, repository, version, upgrade=False, err=err, **opts) - - -def _migrate(url, repository, version, upgrade, err, **opts): - echo = asbool(opts.get('echo', False)) - engine = create_engine(url, echo=echo) - schema = cls_schema(engine, repository) - version = _migrate_version(schema, version, upgrade, err) - - changeset = schema.changeset(version) - for ver, change in changeset: - nextver = ver + changeset.step - print '%s -> %s... '%(ver, nextver), - if opts.get('preview_sql'): - print - print change.log - elif opts.get('preview_py'): - source_ver = max(ver, nextver) - module = schema.repository.version(source_ver).script().module - funcname = upgrade and "upgrade" or "downgrade" - func = getattr(module, funcname) - print - print inspect.getsource(module.upgrade) - else: - schema.runchange(ver, change, changeset.step) - print 'done' - - -def _migrate_version(schema, version, upgrade, err): - if version is None: - return version - # Version is specified: ensure we're upgrading in the right direction - # (current version < target version for upgrading; reverse for down) - version = cls_vernum(version) - cur = schema.version - if upgrade is not None: - if upgrade: - direction = cur <= version - else: - direction = cur >= version - if not direction: - raise exceptions.KnownError(err%(cur, version)) - return version - + def drop_version_control(url, repository, **opts): """%prog drop_version_control URL REPOSITORY_PATH Removes version control from a database. """ - echo = asbool(opts.get('echo', False)) - engine = create_engine(url, echo=echo) - schema = cls_schema(engine, repository) + engine = _construct_engine(url, **opts) + schema = ControlledSchema(engine, repository) schema.drop() @@ -312,9 +268,8 @@ def compare_model_to_db(url, model, repository, **opts): NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label - echo = asbool(opts.get('echo', False)) - engine = create_engine(url, echo=echo) - print cls_schema.compare_model_to_db(engine, model, repository) + engine = _construct_engine(url, **opts) + print ControlledSchema.compare_model_to_db(engine, model, repository) def create_model(url, repository, **opts): @@ -324,10 +279,9 @@ def create_model(url, repository, **opts): NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label - echo = asbool(opts.get('echo', False)) - engine = create_engine(url, echo=echo) + engine = _construct_engine(url, **opts) declarative = opts.get('declarative', False) - print cls_schema.create_model(engine, repository, declarative) + print ControlledSchema.create_model(engine, repository, declarative) # TODO: get rid of this? if we don't add back path param @@ -340,9 +294,8 @@ def make_update_script_for_model(url, oldmodel, model, repository, **opts): NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label - echo = asbool(opts.get('echo', False)) - engine = create_engine(url, echo=echo) - print cls_script_python.make_update_script_for_model( + engine = _construct_engine(url, **opts) + print PythonScript.make_update_script_for_model( engine, oldmodel, model, repository, **opts) @@ -355,7 +308,83 @@ def update_db_from_model(url, model, repository, **opts): NOTE: This is EXPERIMENTAL. """ # TODO: get rid of EXPERIMENTAL label - echo = asbool(opts.get('echo', False)) - engine = create_engine(url, echo=echo) - schema = cls_schema(engine, repository) + engine = _construct_engine(url, **opts) + schema = ControlledSchema(engine, repository) schema.update_db_from_model(model) + + +def _migrate(url, repository, version, upgrade, err, **opts): + engine = _construct_engine(url, **opts) + schema = ControlledSchema(engine, repository) + version = _migrate_version(schema, version, upgrade, err) + + changeset = schema.changeset(version) + for ver, change in changeset: + nextver = ver + changeset.step + print '%s -> %s... '%(ver, nextver), + if opts.get('preview_sql'): + print + print change.log + elif opts.get('preview_py'): + source_ver = max(ver, nextver) + module = schema.repository.version(source_ver).script().module + funcname = upgrade and "upgrade" or "downgrade" + func = getattr(module, funcname) + print + print inspect.getsource(module.upgrade) + else: + schema.runchange(ver, change, changeset.step) + print 'done' + + +def _migrate_version(schema, version, upgrade, err): + if version is None: + return version + # Version is specified: ensure we're upgrading in the right direction + # (current version < target version for upgrading; reverse for down) + version = VerNum(version) + cur = schema.version + if upgrade is not None: + if upgrade: + direction = cur <= version + else: + direction = cur >= version + if not direction: + raise exceptions.KnownError(err % (cur, version)) + return version + + +def _construct_engine(url, **opts): + """Constructs and returns SQLAlchemy engine. + + Currently, there are 2 ways to pass create_engine options to api functions: + + * keyword parameters (starting with `engine_arg_*`) + * python dictionary of options (`engine_dict`) + + NOTE: keyword parameters override `engine_dict` values. + + .. versionadded:: 0.5.4 + """ + # TODO: include docs + + # get options for create_engine + if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict): + kwargs = opts['engine_dict'] + else: + kwargs = dict() + + # DEPRECATED: handle echo the old way + echo = asbool(opts.get('echo', False)) + if echo: + warnings.warn('echo=True parameter is deprecated, pass ' + 'engine_arg_echo=True or engine_dict={"echo": True}', + DeprecationWarning) + kwargs['echo'] = echo + + # parse keyword arguments + for key, value in opts.iteritems(): + if key.startswith('engine_arg_'): + kwargs[key[11:]] = guess_obj_type(value) + + return create_engine(url, **kwargs) diff --git a/migrate/versioning/util/__init__.py b/migrate/versioning/util/__init__.py index 9942f82..60d190f 100644 --- a/migrate/versioning/util/__init__.py +++ b/migrate/versioning/util/__init__.py @@ -31,6 +31,26 @@ def asbool(obj): raise ValueError("String is not true/false: %r" % obj) return bool(obj) +def guess_obj_type(obj): + """Do everything to guess object type from string""" + result = None + + try: + result = asbool(obj) + except: + pass + + if result is None: + try: + result = int(obj) + except: + pass + + if result is not None: + return result + else: + return obj + @decorator def catch_known_errors(f, *a, **kw): """Decorator that catches known api usage errors""" |