summaryrefslogtreecommitdiff
path: root/migrate
diff options
context:
space:
mode:
authoriElectric <unknown>2009-06-03 00:48:59 +0000
committeriElectric <unknown>2009-06-03 00:48:59 +0000
commit4015cbf7e9d1fb845f8c4a2d0d3dc3ad22738e8a (patch)
tree4cb4e02e4357698168b7d1c4be03751b9a55997f /migrate
parentc6883c0d472df8007db1f6c9a9930d1c53c717ae (diff)
downloadsqalchemy-migrate-4015cbf7e9d1fb845f8c4a2d0d3dc3ad22738e8a.tar.gz
Issue 38; add ability to pass arguments/dict for create_engine func
Diffstat (limited to 'migrate')
-rw-r--r--migrate/versioning/api.py189
-rw-r--r--migrate/versioning/util/__init__.py20
2 files changed, 129 insertions, 80 deletions
diff --git a/migrate/versioning/api.py b/migrate/versioning/api.py
index 80ee842..92e8bdb 100644
--- a/migrate/versioning/api.py
+++ b/migrate/versioning/api.py
@@ -14,12 +14,13 @@
import sys
import inspect
+import warnings
from sqlalchemy import create_engine
from migrate.versioning import (exceptions, repository, schema, version,
script as script_) # command name conflict
-from migrate.versioning.util import asbool, catch_known_errors
+from migrate.versioning.util import asbool, catch_known_errors, guess_obj_type
__all__ = [
'help',
@@ -41,10 +42,10 @@ __all__ = [
'update_db_from_model',
]
-cls_repository = repository.Repository
-cls_schema = schema.ControlledSchema
-cls_vernum = version.VerNum
-cls_script_python = script_.PythonScript
+Repository = repository.Repository
+ControlledSchema = schema.ControlledSchema
+VerNum = version.VerNum
+PythonScript = script_.PythonScript
# deprecated
@@ -75,7 +76,7 @@ def create(repository, name, **opts):
'migrate_version'. This table is created in all version-controlled
databases.
"""
- rep = cls_repository.create(repository, name, **opts)
+ repo_path = Repository.create(repository, name, **opts)
@catch_known_errors
@@ -88,8 +89,8 @@ def script(description, repository, **opts):
For instance, manage.py script "Add initial tables" creates:
repository/versions/001_Add_initial_tables.py
"""
- repos = cls_repository(repository)
- repos.create_script(description, **opts)
+ repo = Repository(repository)
+ repo.create_script(description, **opts)
@catch_known_errors
@@ -104,8 +105,8 @@ def script_sql(database, repository, **opts):
repository/versions/001_postgres_upgrade.sql and
repository/versions/001_postgres_postgres.sql
"""
- repos = cls_repository(repository)
- repos.create_script_sql(database, **opts)
+ repo = Repository(repository)
+ repo.create_script_sql(database, **opts)
def test(repository, url=None, **opts):
@@ -116,8 +117,8 @@ def test(repository, url=None, **opts):
bad state. You should therefore better run the test on a copy of
your database.
"""
- engine = create_engine(url)
- repos = cls_repository(repository)
+ engine = _construct_engine(url, **opts)
+ repos = Repository(repository)
script = repos.version(None).script()
# Upgrade
@@ -136,8 +137,8 @@ def version(repository, **opts):
Display the latest version available in a repository.
"""
- repos = cls_repository(repository)
- return repos.latest
+ repo = Repository(repository)
+ return repo.latest
def source(version, dest=None, repository=None, **opts):
@@ -149,8 +150,8 @@ def source(version, dest=None, repository=None, **opts):
"""
if repository is None:
raise exceptions.UsageError("A repository must be specified")
- repos = cls_repository(repository)
- ret = repos.version(version).script().source()
+ repo = Repository(repository)
+ ret = repo.version(version).script().source()
if dest is not None:
dest = open(dest, 'w')
dest.write(ret)
@@ -178,9 +179,8 @@ def version_control(url, repository, version=None, **opts):
identical to what it would be if the database were created from
scratch.
"""
- echo = asbool(opts.get('echo', False))
- engine = create_engine(url, echo=echo)
- cls_schema.create(engine, repository, version)
+ engine = _construct_engine(url, **opts)
+ ControlledSchema.create(engine, repository, version)
def db_version(url, repository, **opts):
@@ -192,9 +192,8 @@ def db_version(url, repository, **opts):
The url should be any valid SQLAlchemy connection string.
"""
- echo = asbool(opts.get('echo', False))
- engine = create_engine(url, echo=echo)
- schema = cls_schema(engine, repository)
+ engine = _construct_engine(url, **opts)
+ schema = ControlledSchema(engine, repository)
return schema.version
@@ -230,58 +229,15 @@ def downgrade(url, repository, version, **opts):
err = "Cannot downgrade a database of version %s to version %s. "\
"Try 'upgrade' instead."
return _migrate(url, repository, version, upgrade=False, err=err, **opts)
-
-
-def _migrate(url, repository, version, upgrade, err, **opts):
- echo = asbool(opts.get('echo', False))
- engine = create_engine(url, echo=echo)
- schema = cls_schema(engine, repository)
- version = _migrate_version(schema, version, upgrade, err)
-
- changeset = schema.changeset(version)
- for ver, change in changeset:
- nextver = ver + changeset.step
- print '%s -> %s... '%(ver, nextver),
- if opts.get('preview_sql'):
- print
- print change.log
- elif opts.get('preview_py'):
- source_ver = max(ver, nextver)
- module = schema.repository.version(source_ver).script().module
- funcname = upgrade and "upgrade" or "downgrade"
- func = getattr(module, funcname)
- print
- print inspect.getsource(module.upgrade)
- else:
- schema.runchange(ver, change, changeset.step)
- print 'done'
-
-
-def _migrate_version(schema, version, upgrade, err):
- if version is None:
- return version
- # Version is specified: ensure we're upgrading in the right direction
- # (current version < target version for upgrading; reverse for down)
- version = cls_vernum(version)
- cur = schema.version
- if upgrade is not None:
- if upgrade:
- direction = cur <= version
- else:
- direction = cur >= version
- if not direction:
- raise exceptions.KnownError(err%(cur, version))
- return version
-
+
def drop_version_control(url, repository, **opts):
"""%prog drop_version_control URL REPOSITORY_PATH
Removes version control from a database.
"""
- echo = asbool(opts.get('echo', False))
- engine = create_engine(url, echo=echo)
- schema = cls_schema(engine, repository)
+ engine = _construct_engine(url, **opts)
+ schema = ControlledSchema(engine, repository)
schema.drop()
@@ -312,9 +268,8 @@ def compare_model_to_db(url, model, repository, **opts):
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
- echo = asbool(opts.get('echo', False))
- engine = create_engine(url, echo=echo)
- print cls_schema.compare_model_to_db(engine, model, repository)
+ engine = _construct_engine(url, **opts)
+ print ControlledSchema.compare_model_to_db(engine, model, repository)
def create_model(url, repository, **opts):
@@ -324,10 +279,9 @@ def create_model(url, repository, **opts):
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
- echo = asbool(opts.get('echo', False))
- engine = create_engine(url, echo=echo)
+ engine = _construct_engine(url, **opts)
declarative = opts.get('declarative', False)
- print cls_schema.create_model(engine, repository, declarative)
+ print ControlledSchema.create_model(engine, repository, declarative)
# TODO: get rid of this? if we don't add back path param
@@ -340,9 +294,8 @@ def make_update_script_for_model(url, oldmodel, model, repository, **opts):
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
- echo = asbool(opts.get('echo', False))
- engine = create_engine(url, echo=echo)
- print cls_script_python.make_update_script_for_model(
+ engine = _construct_engine(url, **opts)
+ print PythonScript.make_update_script_for_model(
engine, oldmodel, model, repository, **opts)
@@ -355,7 +308,83 @@ def update_db_from_model(url, model, repository, **opts):
NOTE: This is EXPERIMENTAL.
""" # TODO: get rid of EXPERIMENTAL label
- echo = asbool(opts.get('echo', False))
- engine = create_engine(url, echo=echo)
- schema = cls_schema(engine, repository)
+ engine = _construct_engine(url, **opts)
+ schema = ControlledSchema(engine, repository)
schema.update_db_from_model(model)
+
+
+def _migrate(url, repository, version, upgrade, err, **opts):
+ engine = _construct_engine(url, **opts)
+ schema = ControlledSchema(engine, repository)
+ version = _migrate_version(schema, version, upgrade, err)
+
+ changeset = schema.changeset(version)
+ for ver, change in changeset:
+ nextver = ver + changeset.step
+ print '%s -> %s... '%(ver, nextver),
+ if opts.get('preview_sql'):
+ print
+ print change.log
+ elif opts.get('preview_py'):
+ source_ver = max(ver, nextver)
+ module = schema.repository.version(source_ver).script().module
+ funcname = upgrade and "upgrade" or "downgrade"
+ func = getattr(module, funcname)
+ print
+ print inspect.getsource(module.upgrade)
+ else:
+ schema.runchange(ver, change, changeset.step)
+ print 'done'
+
+
+def _migrate_version(schema, version, upgrade, err):
+ if version is None:
+ return version
+ # Version is specified: ensure we're upgrading in the right direction
+ # (current version < target version for upgrading; reverse for down)
+ version = VerNum(version)
+ cur = schema.version
+ if upgrade is not None:
+ if upgrade:
+ direction = cur <= version
+ else:
+ direction = cur >= version
+ if not direction:
+ raise exceptions.KnownError(err % (cur, version))
+ return version
+
+
+def _construct_engine(url, **opts):
+ """Constructs and returns SQLAlchemy engine.
+
+ Currently, there are 2 ways to pass create_engine options to api functions:
+
+ * keyword parameters (starting with `engine_arg_*`)
+ * python dictionary of options (`engine_dict`)
+
+ NOTE: keyword parameters override `engine_dict` values.
+
+ .. versionadded:: 0.5.4
+ """
+ # TODO: include docs
+
+ # get options for create_engine
+ if opts.get('engine_dict') and isinstance(opts['engine_dict'], dict):
+ kwargs = opts['engine_dict']
+ else:
+ kwargs = dict()
+
+ # DEPRECATED: handle echo the old way
+ echo = asbool(opts.get('echo', False))
+ if echo:
+ warnings.warn('echo=True parameter is deprecated, pass '
+ 'engine_arg_echo=True or engine_dict={"echo": True}',
+ DeprecationWarning)
+ kwargs['echo'] = echo
+
+ # parse keyword arguments
+ for key, value in opts.iteritems():
+ if key.startswith('engine_arg_'):
+ kwargs[key[11:]] = guess_obj_type(value)
+
+ return create_engine(url, **kwargs)
diff --git a/migrate/versioning/util/__init__.py b/migrate/versioning/util/__init__.py
index 9942f82..60d190f 100644
--- a/migrate/versioning/util/__init__.py
+++ b/migrate/versioning/util/__init__.py
@@ -31,6 +31,26 @@ def asbool(obj):
raise ValueError("String is not true/false: %r" % obj)
return bool(obj)
+def guess_obj_type(obj):
+ """Do everything to guess object type from string"""
+ result = None
+
+ try:
+ result = asbool(obj)
+ except:
+ pass
+
+ if result is None:
+ try:
+ result = int(obj)
+ except:
+ pass
+
+ if result is not None:
+ return result
+ else:
+ return obj
+
@decorator
def catch_known_errors(f, *a, **kw):
"""Decorator that catches known api usage errors"""