diff options
author | Olly Cope <olly@ollycope.com> | 2022-09-01 12:28:48 +0000 |
---|---|---|
committer | Olly Cope <olly@ollycope.com> | 2022-09-01 12:28:48 +0000 |
commit | 030c45f0d17144a681d56de85857eb5819e1c09d (patch) | |
tree | f053eeef1def3d38c4702a99ad2b6c4a058fe336 | |
parent | 290150383e2a12628ccc72c4f4fed252a36d0f40 (diff) | |
download | yoyo-030c45f0d17144a681d56de85857eb5819e1c09d.tar.gz |
Split yoyo.backends into a package
-rw-r--r-- | setup.cfg | 23 | ||||
-rw-r--r-- | yoyo/backends/__init__.py | 13 | ||||
-rw-r--r-- | yoyo/backends/base.py (renamed from yoyo/backends.py) | 238 | ||||
-rw-r--r-- | yoyo/backends/contrib/__init__.py | 0 | ||||
-rw-r--r-- | yoyo/backends/contrib/odbc.py | 31 | ||||
-rw-r--r-- | yoyo/backends/contrib/oracle.py | 46 | ||||
-rw-r--r-- | yoyo/backends/contrib/redshift.py | 62 | ||||
-rw-r--r-- | yoyo/backends/contrib/snowflake.py | 40 | ||||
-rw-r--r-- | yoyo/backends/core/__init__.py | 9 | ||||
-rw-r--r-- | yoyo/backends/core/mysql.py | 68 | ||||
-rw-r--r-- | yoyo/backends/core/postgresql.py | 58 | ||||
-rw-r--r-- | yoyo/backends/core/sqlite3.py | 32 | ||||
-rw-r--r-- | yoyo/tests/conftest.py | 4 | ||||
-rw-r--r-- | yoyo/tests/test_backends.py | 3 | ||||
-rw-r--r-- | yoyo/tests/test_connections.py | 8 |
15 files changed, 386 insertions, 249 deletions
@@ -23,6 +23,9 @@ classifiers = [options] packages = yoyo + yoyo.backends + yoyo.backends.core + yoyo.backends.contrib yoyo.scripts yoyo.internalmigrations install_requires = @@ -44,13 +47,13 @@ console_scripts = yoyo-migrate = yoyo.scripts.main:main yoyo.backends = - odbc = yoyo.backends:ODBCBackend - oracle = yoyo.backends:OracleBackend - postgresql = yoyo.backends:PostgresqlBackend - postgres = yoyo.backends:PostgresqlBackend - psql = yoyo.backends:PostgresqlBackend - mysql = yoyo.backends:MySQLBackend - mysql+mysqldb = yoyo.backends:MySQLdbBackend - sqlite = yoyo.backends:SQLiteBackend - snowflake = yoyo.backends:SnowflakeBackend - redshift = yoyo.backends:RedshiftBackend + odbc = yoyo.backends.contrib.odbc:ODBCBackend + oracle = yoyo.backends.contrib.oracle:OracleBackend + postgres = yoyo.backends.core.postgresql:PostgresqlBackend + postgresql = yoyo.backends.core.postgresql:PostgresqlBackend + psql = yoyo.backends.core.postgresql:PostgresqlBackend + mysql = yoyo.backends.core.mysql:MySQLBackend + mysql+mysqldb = yoyo.backends.core.mysql:MySQLdbBackend + sqlite = yoyo.backends.core.sqlite3:SQLiteBackend + snowflake = yoyo.backends.contrib.snowflake:SnowflakeBackend + redshift = yoyo.backends.contrib.redshift:RedshiftBackend diff --git a/yoyo/backends/__init__.py b/yoyo/backends/__init__.py new file mode 100644 index 0000000..c4d6f6e --- /dev/null +++ b/yoyo/backends/__init__.py @@ -0,0 +1,13 @@ +from yoyo.backends.base import DatabaseBackend +from yoyo.backends.base import get_backend_class +from yoyo.backends.core import MySQLBackend +from yoyo.backends.core import SQLiteBackend +from yoyo.backends.core import PostgresqlBackend + +__all__ = [ + "DatabaseBackend", + "get_backend_class", + "MySQLBackend", + "SQLiteBackend", + "PostgresqlBackend", +] diff --git a/yoyo/backends.py b/yoyo/backends/base.py index dc60130..f288d19 100644 --- a/yoyo/backends.py +++ b/yoyo/backends/base.py @@ -28,15 +28,15 @@ import socket import time import uuid -from . import exceptions -from . import internalmigrations -from . import utils -from .migrations import topological_sort +from yoyo import exceptions +from yoyo import internalmigrations +from yoyo import utils +from yoyo.migrations import topological_sort logger = getLogger("yoyo.migrations") -class TransactionManager(object): +class TransactionManager: """ Returned by the :meth:`~yoyo.backends.DatabaseBackend.transaction` context manager. @@ -112,7 +112,7 @@ class SavepointTransactionManager(TransactionManager): self.backend.savepoint_rollback(self.id) -class DatabaseBackend(object): +class DatabaseBackend: driver_module = "" # type: str @@ -566,232 +566,6 @@ class DatabaseBackend(object): } -class ODBCBackend(DatabaseBackend): - driver_module = "pyodbc" - - def connect(self, dburi): - args = [ - ("UID", dburi.username), - ("PWD", dburi.password), - ("ServerName", dburi.hostname), - ("Port", dburi.port), - ("Database", dburi.database), - ] - args.extend(dburi.args.items()) - s = ";".join("{}={}".format(k, v) for k, v in args if v is not None) - return self.driver.connect(s) - - -class OracleBackend(DatabaseBackend): - - driver_module = "cx_Oracle" - list_tables_sql = "SELECT table_name FROM all_tables WHERE owner=user" - - def begin(self): - """Oracle is always in a transaction, and has no "BEGIN" statement.""" - self._in_transaction = True - - def connect(self, dburi): - kwargs = dburi.args - if dburi.username is not None: - kwargs["user"] = dburi.username - if dburi.password is not None: - kwargs["password"] = dburi.password - # Oracle combines the hostname, port and database into a single DSN. - # The DSN can also be a "net service name" - kwargs["dsn"] = "" - if dburi.hostname is not None: - kwargs["dsn"] = dburi.hostname - if dburi.port is not None: - kwargs["dsn"] += ":{0}".format(dburi.port) - if dburi.database is not None: - if kwargs["dsn"]: - kwargs["dsn"] += "/{0}".format(dburi.database) - else: - kwargs["dsn"] = dburi.database - - return self.driver.connect(**kwargs) - - -class MySQLBackend(DatabaseBackend): - - driver_module = "pymysql" - list_tables_sql = ( - "SELECT table_name FROM information_schema.tables " - "WHERE table_schema = :database" - ) - - def connect(self, dburi): - kwargs = {"db": dburi.database} - kwargs.update(dburi.args) - if dburi.username is not None: - kwargs["user"] = dburi.username - if dburi.password is not None: - kwargs["passwd"] = dburi.password - if dburi.hostname is not None: - kwargs["host"] = dburi.hostname - if dburi.port is not None: - kwargs["port"] = dburi.port - if "unix_socket" in dburi.args: - kwargs["unix_socket"] = dburi.args["unix_socket"] - if "ssl" in dburi.args: - kwargs["ssl"] = {} - - if "sslca" in dburi.args: - kwargs["ssl"]["ca"] = dburi.args["sslca"] - - if "sslcapath" in dburi.args: - kwargs["ssl"]["capath"] = dburi.args["sslcapath"] - - if "sslcert" in dburi.args: - kwargs["ssl"]["cert"] = dburi.args["sslcert"] - - if "sslkey" in dburi.args: - kwargs["ssl"]["key"] = dburi.args["sslkey"] - - if "sslcipher" in dburi.args: - kwargs["ssl"]["cipher"] = dburi.args["sslcipher"] - - kwargs["db"] = dburi.database - return self.driver.connect(**kwargs) - - def quote_identifier(self, identifier): - sql_mode = self.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] - if "ansi_quotes" in sql_mode.lower(): - return super(MySQLBackend, self).quote_identifier(identifier) - return "`{}`".format(identifier) - - -class MySQLdbBackend(MySQLBackend): - driver_module = "MySQLdb" - - -class SQLiteBackend(DatabaseBackend): - - driver_module = "sqlite3" - list_tables_sql = "SELECT name FROM sqlite_master WHERE type = 'table'" - - def connect(self, dburi): - # Ensure that multiple connections share the same data - # https://sqlite.org/sharedcache.html - conn = self.driver.connect( - f"file:{dburi.database}?cache=shared", - uri=True, - detect_types=self.driver.PARSE_DECLTYPES, - ) - conn.isolation_level = None - return conn - - -class PostgresqlBackend(DatabaseBackend): - - driver_module = "psycopg2" - schema = None - list_tables_sql = ( - "SELECT table_name FROM information_schema.tables " - "WHERE table_schema = :schema" - ) - - def connect(self, dburi): - kwargs = {"dbname": dburi.database} - kwargs.update(dburi.args) - if dburi.username is not None: - kwargs["user"] = dburi.username - if dburi.password is not None: - kwargs["password"] = dburi.password - if dburi.port is not None: - kwargs["port"] = dburi.port - if dburi.hostname is not None: - kwargs["host"] = dburi.hostname - self.schema = kwargs.pop("schema", None) - return self.driver.connect(**kwargs) - - @contextmanager - def disable_transactions(self): - with super(PostgresqlBackend, self).disable_transactions(): - saved = self.connection.autocommit - self.connection.autocommit = True - yield - self.connection.autocommit = saved - - def init_connection(self, connection): - if self.schema: - cursor = connection.cursor() - cursor.execute("SET search_path TO {}".format(self.schema)) - - def list_tables(self): - current_schema = self.execute("SELECT current_schema").fetchone()[0] - return super(PostgresqlBackend, self).list_tables(schema=current_schema) - - -class RedshiftBackend(PostgresqlBackend): - def list_tables(self): - current_schema = self.execute("SELECT current_schema()").fetchone()[0] - return super(PostgresqlBackend, self).list_tables(schema=current_schema) - - # Redshift does not support ROLLBACK TO SAVEPOINT - def savepoint(self, id): - pass - - def savepoint_release(self, id): - pass - - def savepoint_rollback(self, id): - self.rollback() - - # Redshift does not enforce primary and unique keys - def _insert_lock_row(self, pid, timeout, poll_interval=0.5): - poll_interval = min(poll_interval, timeout) - started = time.time() - while True: - with self.transaction(): - # prevents isolation violation errors - self.execute("LOCK {}".format(self.lock_table_quoted)) - cursor = self.execute( - "SELECT pid FROM {}".format(self.lock_table_quoted) - ) - row = cursor.fetchone() - if not row: - self.execute( - "INSERT INTO {} (locked, ctime, pid) " - "VALUES (1, :when, :pid)".format(self.lock_table_quoted), - {"when": datetime.utcnow(), "pid": pid}, - ) - return - elif timeout and time.time() > started + timeout: - raise exceptions.LockTimeout( - "Process {} has locked this database " - "(run yoyo break-lock to remove this lock)".format(row[0]) - ) - else: - time.sleep(poll_interval) - - -class SnowflakeBackend(DatabaseBackend): - - driver_module = "snowflake.connector" - - def connect(self, dburi): - database, schema = dburi.database.split("/") - return self.driver.connect( - user=dburi.username, - password=dburi.password, - account=dburi.hostname, - database=database, - schema=schema, - **dburi.args, - ) - - def savepoint(self, id): - pass - - def savepoint_release(self, id): - pass - - def savepoint_rollback(self, id): - pass - - def get_backend_class(name): backend_eps = entry_points(group="yoyo.backends") return backend_eps[name].load() diff --git a/yoyo/backends/contrib/__init__.py b/yoyo/backends/contrib/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/yoyo/backends/contrib/__init__.py diff --git a/yoyo/backends/contrib/odbc.py b/yoyo/backends/contrib/odbc.py new file mode 100644 index 0000000..05eb594 --- /dev/null +++ b/yoyo/backends/contrib/odbc.py @@ -0,0 +1,31 @@ +# Copyright 2015 Oliver Cope +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from yoyo.backends.base import DatabaseBackend + + +class ODBCBackend(DatabaseBackend): + driver_module = "pyodbc" + + def connect(self, dburi): + args = [ + ("UID", dburi.username), + ("PWD", dburi.password), + ("ServerName", dburi.hostname), + ("Port", dburi.port), + ("Database", dburi.database), + ] + args.extend(dburi.args.items()) + s = ";".join("{}={}".format(k, v) for k, v in args if v is not None) + return self.driver.connect(s) diff --git a/yoyo/backends/contrib/oracle.py b/yoyo/backends/contrib/oracle.py new file mode 100644 index 0000000..24f644c --- /dev/null +++ b/yoyo/backends/contrib/oracle.py @@ -0,0 +1,46 @@ +# Copyright 2015 Oliver Cope +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from yoyo.backends.base import DatabaseBackend + + +class OracleBackend(DatabaseBackend): + + driver_module = "cx_Oracle" + list_tables_sql = "SELECT table_name FROM all_tables WHERE owner=user" + + def begin(self): + """Oracle is always in a transaction, and has no "BEGIN" statement.""" + self._in_transaction = True + + def connect(self, dburi): + kwargs = dburi.args + if dburi.username is not None: + kwargs["user"] = dburi.username + if dburi.password is not None: + kwargs["password"] = dburi.password + # Oracle combines the hostname, port and database into a single DSN. + # The DSN can also be a "net service name" + kwargs["dsn"] = "" + if dburi.hostname is not None: + kwargs["dsn"] = dburi.hostname + if dburi.port is not None: + kwargs["dsn"] += ":{0}".format(dburi.port) + if dburi.database is not None: + if kwargs["dsn"]: + kwargs["dsn"] += "/{0}".format(dburi.database) + else: + kwargs["dsn"] = dburi.database + + return self.driver.connect(**kwargs) diff --git a/yoyo/backends/contrib/redshift.py b/yoyo/backends/contrib/redshift.py new file mode 100644 index 0000000..80fdb5a --- /dev/null +++ b/yoyo/backends/contrib/redshift.py @@ -0,0 +1,62 @@ +# Copyright 2015 Oliver Cope +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import time +from datetime import datetime + +from yoyo import exceptions +from yoyo.backends.core.postgresql import PostgresqlBackend + + +class RedshiftBackend(PostgresqlBackend): + def list_tables(self): + current_schema = self.execute("SELECT current_schema()").fetchone()[0] + return super(PostgresqlBackend, self).list_tables(schema=current_schema) + + # Redshift does not support ROLLBACK TO SAVEPOINT + def savepoint(self, id): + pass + + def savepoint_release(self, id): + pass + + def savepoint_rollback(self, id): + self.rollback() + + # Redshift does not enforce primary and unique keys + def _insert_lock_row(self, pid, timeout, poll_interval=0.5): + poll_interval = min(poll_interval, timeout) + started = time.time() + while True: + with self.transaction(): + # prevents isolation violation errors + self.execute("LOCK {}".format(self.lock_table_quoted)) + cursor = self.execute( + "SELECT pid FROM {}".format(self.lock_table_quoted) + ) + row = cursor.fetchone() + if not row: + self.execute( + "INSERT INTO {} (locked, ctime, pid) " + "VALUES (1, :when, :pid)".format(self.lock_table_quoted), + {"when": datetime.utcnow(), "pid": pid}, + ) + return + elif timeout and time.time() > started + timeout: + raise exceptions.LockTimeout( + "Process {} has locked this database " + "(run yoyo break-lock to remove this lock)".format(row[0]) + ) + else: + time.sleep(poll_interval) diff --git a/yoyo/backends/contrib/snowflake.py b/yoyo/backends/contrib/snowflake.py new file mode 100644 index 0000000..244a156 --- /dev/null +++ b/yoyo/backends/contrib/snowflake.py @@ -0,0 +1,40 @@ +# Copyright 2015 Oliver Cope +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from yoyo.backends.base import DatabaseBackend + + +class SnowflakeBackend(DatabaseBackend): + + driver_module = "snowflake.connector" + + def connect(self, dburi): + database, schema = dburi.database.split("/") + return self.driver.connect( + user=dburi.username, + password=dburi.password, + account=dburi.hostname, + database=database, + schema=schema, + **dburi.args, + ) + + def savepoint(self, id): + pass + + def savepoint_release(self, id): + pass + + def savepoint_rollback(self, id): + pass diff --git a/yoyo/backends/core/__init__.py b/yoyo/backends/core/__init__.py new file mode 100644 index 0000000..183e24c --- /dev/null +++ b/yoyo/backends/core/__init__.py @@ -0,0 +1,9 @@ +from yoyo.backends.core.mysql import MySQLBackend +from yoyo.backends.core.sqlite3 import SQLiteBackend +from yoyo.backends.core.postgresql import PostgresqlBackend + +__all__ = [ + "MySQLBackend", + "SQLiteBackend", + "PostgresqlBackend", +] diff --git a/yoyo/backends/core/mysql.py b/yoyo/backends/core/mysql.py new file mode 100644 index 0000000..e8efb97 --- /dev/null +++ b/yoyo/backends/core/mysql.py @@ -0,0 +1,68 @@ +# Copyright 2015 Oliver Cope +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from yoyo.backends.base import DatabaseBackend + + +class MySQLBackend(DatabaseBackend): + + driver_module = "pymysql" + list_tables_sql = ( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = :database" + ) + + def connect(self, dburi): + kwargs = {"db": dburi.database} + kwargs.update(dburi.args) + if dburi.username is not None: + kwargs["user"] = dburi.username + if dburi.password is not None: + kwargs["passwd"] = dburi.password + if dburi.hostname is not None: + kwargs["host"] = dburi.hostname + if dburi.port is not None: + kwargs["port"] = dburi.port + if "unix_socket" in dburi.args: + kwargs["unix_socket"] = dburi.args["unix_socket"] + if "ssl" in dburi.args: + kwargs["ssl"] = {} + + if "sslca" in dburi.args: + kwargs["ssl"]["ca"] = dburi.args["sslca"] + + if "sslcapath" in dburi.args: + kwargs["ssl"]["capath"] = dburi.args["sslcapath"] + + if "sslcert" in dburi.args: + kwargs["ssl"]["cert"] = dburi.args["sslcert"] + + if "sslkey" in dburi.args: + kwargs["ssl"]["key"] = dburi.args["sslkey"] + + if "sslcipher" in dburi.args: + kwargs["ssl"]["cipher"] = dburi.args["sslcipher"] + + kwargs["db"] = dburi.database + return self.driver.connect(**kwargs) + + def quote_identifier(self, identifier): + sql_mode = self.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1] + if "ansi_quotes" in sql_mode.lower(): + return super(MySQLBackend, self).quote_identifier(identifier) + return "`{}`".format(identifier) + + +class MySQLdbBackend(MySQLBackend): + driver_module = "MySQLdb" diff --git a/yoyo/backends/core/postgresql.py b/yoyo/backends/core/postgresql.py new file mode 100644 index 0000000..b5bf3c3 --- /dev/null +++ b/yoyo/backends/core/postgresql.py @@ -0,0 +1,58 @@ +# Copyright 2015 Oliver Cope +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from contextlib import contextmanager + +from yoyo.backends.base import DatabaseBackend + + +class PostgresqlBackend(DatabaseBackend): + + driver_module = "psycopg2" + schema = None + list_tables_sql = ( + "SELECT table_name FROM information_schema.tables " + "WHERE table_schema = :schema" + ) + + def connect(self, dburi): + kwargs = {"dbname": dburi.database} + kwargs.update(dburi.args) + if dburi.username is not None: + kwargs["user"] = dburi.username + if dburi.password is not None: + kwargs["password"] = dburi.password + if dburi.port is not None: + kwargs["port"] = dburi.port + if dburi.hostname is not None: + kwargs["host"] = dburi.hostname + self.schema = kwargs.pop("schema", None) + return self.driver.connect(**kwargs) + + @contextmanager + def disable_transactions(self): + with super(PostgresqlBackend, self).disable_transactions(): + saved = self.connection.autocommit + self.connection.autocommit = True + yield + self.connection.autocommit = saved + + def init_connection(self, connection): + if self.schema: + cursor = connection.cursor() + cursor.execute("SET search_path TO {}".format(self.schema)) + + def list_tables(self): + current_schema = self.execute("SELECT current_schema").fetchone()[0] + return super(PostgresqlBackend, self).list_tables(schema=current_schema) diff --git a/yoyo/backends/core/sqlite3.py b/yoyo/backends/core/sqlite3.py new file mode 100644 index 0000000..43345b0 --- /dev/null +++ b/yoyo/backends/core/sqlite3.py @@ -0,0 +1,32 @@ +# Copyright 2015 Oliver Cope +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from yoyo.backends.base import DatabaseBackend + + +class SQLiteBackend(DatabaseBackend): + + driver_module = "sqlite3" + list_tables_sql = "SELECT name FROM sqlite_master WHERE type = 'table'" + + def connect(self, dburi): + # Ensure that multiple connections share the same data + # https://sqlite.org/sharedcache.html + conn = self.driver.connect( + f"file:{dburi.database}?cache=shared", + uri=True, + detect_types=self.driver.PARSE_DECLTYPES, + ) + conn.isolation_level = None + return conn diff --git a/yoyo/tests/conftest.py b/yoyo/tests/conftest.py index 1a8ce51..1bf8d4a 100644 --- a/yoyo/tests/conftest.py +++ b/yoyo/tests/conftest.py @@ -1,6 +1,6 @@ import pytest -from yoyo import backends +import yoyo.backends.core from yoyo.connections import get_backend from yoyo.tests import dburi_sqlite3 from yoyo.tests import get_test_backends @@ -13,7 +13,7 @@ def _backend(dburi): """ backend = get_backend(dburi) with backend.transaction(): - if backend.__class__ is backends.MySQLBackend: + if backend.__class__ is yoyo.backends.core.MySQLBackend: backend.execute( "CREATE TABLE yoyo_t (id CHAR(1) primary key) ENGINE=InnoDB" ) diff --git a/yoyo/tests/test_backends.py b/yoyo/tests/test_backends.py index c812f3d..8d99009 100644 --- a/yoyo/tests/test_backends.py +++ b/yoyo/tests/test_backends.py @@ -10,6 +10,7 @@ import pytest from yoyo import backends from yoyo import read_migrations from yoyo import exceptions +from yoyo.backends.contrib.redshift import RedshiftBackend from yoyo.connections import get_backend from yoyo.tests import get_test_backends from yoyo.tests import get_test_dburis @@ -74,8 +75,8 @@ class TestTransactionHandling(object): def test_backend_detects_transactional_ddl(self, backend): expected = { + RedshiftBackend: True, backends.PostgresqlBackend: True, - backends.RedshiftBackend: True, backends.SQLiteBackend: True, backends.MySQLBackend: False, } diff --git a/yoyo/tests/test_connections.py b/yoyo/tests/test_connections.py index ec4b70b..10b32c6 100644 --- a/yoyo/tests/test_connections.py +++ b/yoyo/tests/test_connections.py @@ -18,6 +18,8 @@ from unittest.mock import patch, call, MagicMock import pytest from yoyo.connections import parse_uri, BadConnectionURI +from yoyo import backends +from yoyo.backends.contrib import odbc class MockDatabaseError(Exception): @@ -73,17 +75,15 @@ class TestParseURI: @patch( - "yoyo.backends.get_dbapi_module", + "yoyo.backends.base.get_dbapi_module", return_value=MagicMock(DatabaseError=MockDatabaseError, paramstyle="qmark"), ) def test_connections(get_dbapi_module): - from yoyo import backends - u = parse_uri("odbc://scott:tiger@db.example.org:42/northwind?foo=bar") cases = [ ( - backends.ODBCBackend, + odbc.ODBCBackend, "pyodbc", call( "UID=scott;PWD=tiger;ServerName=db.example.org;" |