diff options
| author | Pete Keen <pete@bugsplat.info> | 2011-06-10 13:56:30 -0700 |
|---|---|---|
| committer | Pete Keen <pete@bugsplat.info> | 2011-06-10 13:56:30 -0700 |
| commit | b78a249b5207e3e0ab2df3b469c3af48acfe5ff7 (patch) | |
| tree | 60e809542ce2e7cb729de35dc809b7f1795ab2e1 /migrate/versioning/version.py | |
| parent | cbebf76ade778042e3b38f3b14455fec6b7e3bee (diff) | |
| download | sqlalchemy-migrate-b78a249b5207e3e0ab2df3b469c3af48acfe5ff7.tar.gz | |
Allow descriptions in sql change script filenames
Diffstat (limited to 'migrate/versioning/version.py')
| -rw-r--r-- | migrate/versioning/version.py | 25 |
1 files changed, 19 insertions, 6 deletions
diff --git a/migrate/versioning/version.py b/migrate/versioning/version.py index fdb78a9..f41a71c 100644 --- a/migrate/versioning/version.py +++ b/migrate/versioning/version.py @@ -114,14 +114,22 @@ class Collection(pathed.Pathed): script.PythonScript.create(filepath, **k) self.versions[ver] = Version(ver, self.path, [filename]) - def create_new_sql_version(self, database, **k): + def create_new_sql_version(self, database, description, **k): """Create SQL files for new version""" ver = self._next_ver_num(k.pop('use_timestamp_numbering', False)) self.versions[ver] = Version(ver, self.path, []) + extra = str_to_filename(description) + + if extra: + if extra == '_': + extra = '' + elif not extra.startswith('_'): + extra = '_%s' % extra + # Create new files. for op in ('upgrade', 'downgrade'): - filename = '%03d_%s_%s.sql' % (ver, database, op) + filename = '%03d%s_%s_%s.sql' % (ver, extra, database, op) filepath = self._version_path(filename) script.SqlScript.create(filepath, **k) self.versions[ver].add_script(filepath) @@ -185,18 +193,23 @@ class Version(object): elif path.endswith(Extensions.sql): self._add_script_sql(path) - SQL_FILENAME = re.compile(r'^(\d+)_([^_]+)_([^_]+).sql') + SQL_FILENAME = re.compile(r'^.*\.sql') def _add_script_sql(self, path): basename = os.path.basename(path) match = self.SQL_FILENAME.match(basename) - + if match: - version, dbms, op = match.group(1), match.group(2), match.group(3) + basename = basename.replace('.sql', '') + parts = basename.split('_') + assert len(parts) >= 3 + version = parts[0] + op = parts[-1] + dbms = parts[-2] else: raise exceptions.ScriptError( "Invalid SQL script name %s " % basename + \ - "(needs to be ###_database_operation.sql)") + "(needs to be ###_description_database_operation.sql)") # File the script into a dictionary self.sql.setdefault(dbms, {})[op] = script.SqlScript(path) |
