summaryrefslogtreecommitdiff
path: root/migrate/versioning/version.py
diff options
context:
space:
mode:
authorPete Keen <pete@bugsplat.info>2011-06-10 13:56:30 -0700
committerPete Keen <pete@bugsplat.info>2011-06-10 13:56:30 -0700
commitb78a249b5207e3e0ab2df3b469c3af48acfe5ff7 (patch)
tree60e809542ce2e7cb729de35dc809b7f1795ab2e1 /migrate/versioning/version.py
parentcbebf76ade778042e3b38f3b14455fec6b7e3bee (diff)
downloadsqlalchemy-migrate-b78a249b5207e3e0ab2df3b469c3af48acfe5ff7.tar.gz
Allow descriptions in sql change script filenames
Diffstat (limited to 'migrate/versioning/version.py')
-rw-r--r--migrate/versioning/version.py25
1 files changed, 19 insertions, 6 deletions
diff --git a/migrate/versioning/version.py b/migrate/versioning/version.py
index fdb78a9..f41a71c 100644
--- a/migrate/versioning/version.py
+++ b/migrate/versioning/version.py
@@ -114,14 +114,22 @@ class Collection(pathed.Pathed):
script.PythonScript.create(filepath, **k)
self.versions[ver] = Version(ver, self.path, [filename])
- def create_new_sql_version(self, database, **k):
+ def create_new_sql_version(self, database, description, **k):
"""Create SQL files for new version"""
ver = self._next_ver_num(k.pop('use_timestamp_numbering', False))
self.versions[ver] = Version(ver, self.path, [])
+ extra = str_to_filename(description)
+
+ if extra:
+ if extra == '_':
+ extra = ''
+ elif not extra.startswith('_'):
+ extra = '_%s' % extra
+
# Create new files.
for op in ('upgrade', 'downgrade'):
- filename = '%03d_%s_%s.sql' % (ver, database, op)
+ filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op)
filepath = self._version_path(filename)
script.SqlScript.create(filepath, **k)
self.versions[ver].add_script(filepath)
@@ -185,18 +193,23 @@ class Version(object):
elif path.endswith(Extensions.sql):
self._add_script_sql(path)
- SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql')
+ SQL_FILENAME = re.compile(r'^.*\.sql')
def _add_script_sql(self, path):
basename = os.path.basename(path)
match = self.SQL_FILENAME.match(basename)
-
+
if match:
- version, dbms, op = match.group(1), match.group(2), match.group(3)
+ basename = basename.replace('.sql', '')
+ parts = basename.split('_')
+ assert len(parts) >= 3
+ version = parts[0]
+ op = parts[-1]
+ dbms = parts[-2]
else:
raise exceptions.ScriptError(
"Invalid SQL script name %s " % basename + \
- "(needs to be ###_database_operation.sql)")
+ "(needs to be ###_description_database_operation.sql)")
# File the script into a dictionary
self.sql.setdefault(dbms, {})[op] = script.SqlScript(path)