summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorOlly Cope <olly@ollycope.com>2022-09-01 12:28:48 +0000
committerOlly Cope <olly@ollycope.com>2022-09-01 12:28:48 +0000
commit030c45f0d17144a681d56de85857eb5819e1c09d (patch)
treef053eeef1def3d38c4702a99ad2b6c4a058fe336
parent290150383e2a12628ccc72c4f4fed252a36d0f40 (diff)
downloadyoyo-030c45f0d17144a681d56de85857eb5819e1c09d.tar.gz
Split yoyo.backends into a package
-rw-r--r--setup.cfg23
-rw-r--r--yoyo/backends/__init__.py13
-rw-r--r--yoyo/backends/base.py (renamed from yoyo/backends.py)238
-rw-r--r--yoyo/backends/contrib/__init__.py0
-rw-r--r--yoyo/backends/contrib/odbc.py31
-rw-r--r--yoyo/backends/contrib/oracle.py46
-rw-r--r--yoyo/backends/contrib/redshift.py62
-rw-r--r--yoyo/backends/contrib/snowflake.py40
-rw-r--r--yoyo/backends/core/__init__.py9
-rw-r--r--yoyo/backends/core/mysql.py68
-rw-r--r--yoyo/backends/core/postgresql.py58
-rw-r--r--yoyo/backends/core/sqlite3.py32
-rw-r--r--yoyo/tests/conftest.py4
-rw-r--r--yoyo/tests/test_backends.py3
-rw-r--r--yoyo/tests/test_connections.py8
15 files changed, 386 insertions, 249 deletions
diff --git a/setup.cfg b/setup.cfg
index 12f68ec..edda49f 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -23,6 +23,9 @@ classifiers =
[options]
packages =
yoyo
+ yoyo.backends
+ yoyo.backends.core
+ yoyo.backends.contrib
yoyo.scripts
yoyo.internalmigrations
install_requires =
@@ -44,13 +47,13 @@ console_scripts =
yoyo-migrate = yoyo.scripts.main:main
yoyo.backends =
- odbc = yoyo.backends:ODBCBackend
- oracle = yoyo.backends:OracleBackend
- postgresql = yoyo.backends:PostgresqlBackend
- postgres = yoyo.backends:PostgresqlBackend
- psql = yoyo.backends:PostgresqlBackend
- mysql = yoyo.backends:MySQLBackend
- mysql+mysqldb = yoyo.backends:MySQLdbBackend
- sqlite = yoyo.backends:SQLiteBackend
- snowflake = yoyo.backends:SnowflakeBackend
- redshift = yoyo.backends:RedshiftBackend
+ odbc = yoyo.backends.contrib.odbc:ODBCBackend
+ oracle = yoyo.backends.contrib.oracle:OracleBackend
+ postgres = yoyo.backends.core.postgresql:PostgresqlBackend
+ postgresql = yoyo.backends.core.postgresql:PostgresqlBackend
+ psql = yoyo.backends.core.postgresql:PostgresqlBackend
+ mysql = yoyo.backends.core.mysql:MySQLBackend
+ mysql+mysqldb = yoyo.backends.core.mysql:MySQLdbBackend
+ sqlite = yoyo.backends.core.sqlite3:SQLiteBackend
+ snowflake = yoyo.backends.contrib.snowflake:SnowflakeBackend
+ redshift = yoyo.backends.contrib.redshift:RedshiftBackend
diff --git a/yoyo/backends/__init__.py b/yoyo/backends/__init__.py
new file mode 100644
index 0000000..c4d6f6e
--- /dev/null
+++ b/yoyo/backends/__init__.py
@@ -0,0 +1,13 @@
+from yoyo.backends.base import DatabaseBackend
+from yoyo.backends.base import get_backend_class
+from yoyo.backends.core import MySQLBackend
+from yoyo.backends.core import SQLiteBackend
+from yoyo.backends.core import PostgresqlBackend
+
+__all__ = [
+ "DatabaseBackend",
+ "get_backend_class",
+ "MySQLBackend",
+ "SQLiteBackend",
+ "PostgresqlBackend",
+]
diff --git a/yoyo/backends.py b/yoyo/backends/base.py
index dc60130..f288d19 100644
--- a/yoyo/backends.py
+++ b/yoyo/backends/base.py
@@ -28,15 +28,15 @@ import socket
import time
import uuid
-from . import exceptions
-from . import internalmigrations
-from . import utils
-from .migrations import topological_sort
+from yoyo import exceptions
+from yoyo import internalmigrations
+from yoyo import utils
+from yoyo.migrations import topological_sort
logger = getLogger("yoyo.migrations")
-class TransactionManager(object):
+class TransactionManager:
"""
Returned by the :meth:`~yoyo.backends.DatabaseBackend.transaction`
context manager.
@@ -112,7 +112,7 @@ class SavepointTransactionManager(TransactionManager):
self.backend.savepoint_rollback(self.id)
-class DatabaseBackend(object):
+class DatabaseBackend:
driver_module = "" # type: str
@@ -566,232 +566,6 @@ class DatabaseBackend(object):
}
-class ODBCBackend(DatabaseBackend):
- driver_module = "pyodbc"
-
- def connect(self, dburi):
- args = [
- ("UID", dburi.username),
- ("PWD", dburi.password),
- ("ServerName", dburi.hostname),
- ("Port", dburi.port),
- ("Database", dburi.database),
- ]
- args.extend(dburi.args.items())
- s = ";".join("{}={}".format(k, v) for k, v in args if v is not None)
- return self.driver.connect(s)
-
-
-class OracleBackend(DatabaseBackend):
-
- driver_module = "cx_Oracle"
- list_tables_sql = "SELECT table_name FROM all_tables WHERE owner=user"
-
- def begin(self):
- """Oracle is always in a transaction, and has no "BEGIN" statement."""
- self._in_transaction = True
-
- def connect(self, dburi):
- kwargs = dburi.args
- if dburi.username is not None:
- kwargs["user"] = dburi.username
- if dburi.password is not None:
- kwargs["password"] = dburi.password
- # Oracle combines the hostname, port and database into a single DSN.
- # The DSN can also be a "net service name"
- kwargs["dsn"] = ""
- if dburi.hostname is not None:
- kwargs["dsn"] = dburi.hostname
- if dburi.port is not None:
- kwargs["dsn"] += ":{0}".format(dburi.port)
- if dburi.database is not None:
- if kwargs["dsn"]:
- kwargs["dsn"] += "/{0}".format(dburi.database)
- else:
- kwargs["dsn"] = dburi.database
-
- return self.driver.connect(**kwargs)
-
-
-class MySQLBackend(DatabaseBackend):
-
- driver_module = "pymysql"
- list_tables_sql = (
- "SELECT table_name FROM information_schema.tables "
- "WHERE table_schema = :database"
- )
-
- def connect(self, dburi):
- kwargs = {"db": dburi.database}
- kwargs.update(dburi.args)
- if dburi.username is not None:
- kwargs["user"] = dburi.username
- if dburi.password is not None:
- kwargs["passwd"] = dburi.password
- if dburi.hostname is not None:
- kwargs["host"] = dburi.hostname
- if dburi.port is not None:
- kwargs["port"] = dburi.port
- if "unix_socket" in dburi.args:
- kwargs["unix_socket"] = dburi.args["unix_socket"]
- if "ssl" in dburi.args:
- kwargs["ssl"] = {}
-
- if "sslca" in dburi.args:
- kwargs["ssl"]["ca"] = dburi.args["sslca"]
-
- if "sslcapath" in dburi.args:
- kwargs["ssl"]["capath"] = dburi.args["sslcapath"]
-
- if "sslcert" in dburi.args:
- kwargs["ssl"]["cert"] = dburi.args["sslcert"]
-
- if "sslkey" in dburi.args:
- kwargs["ssl"]["key"] = dburi.args["sslkey"]
-
- if "sslcipher" in dburi.args:
- kwargs["ssl"]["cipher"] = dburi.args["sslcipher"]
-
- kwargs["db"] = dburi.database
- return self.driver.connect(**kwargs)
-
- def quote_identifier(self, identifier):
- sql_mode = self.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1]
- if "ansi_quotes" in sql_mode.lower():
- return super(MySQLBackend, self).quote_identifier(identifier)
- return "`{}`".format(identifier)
-
-
-class MySQLdbBackend(MySQLBackend):
- driver_module = "MySQLdb"
-
-
-class SQLiteBackend(DatabaseBackend):
-
- driver_module = "sqlite3"
- list_tables_sql = "SELECT name FROM sqlite_master WHERE type = 'table'"
-
- def connect(self, dburi):
- # Ensure that multiple connections share the same data
- # https://sqlite.org/sharedcache.html
- conn = self.driver.connect(
- f"file:{dburi.database}?cache=shared",
- uri=True,
- detect_types=self.driver.PARSE_DECLTYPES,
- )
- conn.isolation_level = None
- return conn
-
-
-class PostgresqlBackend(DatabaseBackend):
-
- driver_module = "psycopg2"
- schema = None
- list_tables_sql = (
- "SELECT table_name FROM information_schema.tables "
- "WHERE table_schema = :schema"
- )
-
- def connect(self, dburi):
- kwargs = {"dbname": dburi.database}
- kwargs.update(dburi.args)
- if dburi.username is not None:
- kwargs["user"] = dburi.username
- if dburi.password is not None:
- kwargs["password"] = dburi.password
- if dburi.port is not None:
- kwargs["port"] = dburi.port
- if dburi.hostname is not None:
- kwargs["host"] = dburi.hostname
- self.schema = kwargs.pop("schema", None)
- return self.driver.connect(**kwargs)
-
- @contextmanager
- def disable_transactions(self):
- with super(PostgresqlBackend, self).disable_transactions():
- saved = self.connection.autocommit
- self.connection.autocommit = True
- yield
- self.connection.autocommit = saved
-
- def init_connection(self, connection):
- if self.schema:
- cursor = connection.cursor()
- cursor.execute("SET search_path TO {}".format(self.schema))
-
- def list_tables(self):
- current_schema = self.execute("SELECT current_schema").fetchone()[0]
- return super(PostgresqlBackend, self).list_tables(schema=current_schema)
-
-
-class RedshiftBackend(PostgresqlBackend):
- def list_tables(self):
- current_schema = self.execute("SELECT current_schema()").fetchone()[0]
- return super(PostgresqlBackend, self).list_tables(schema=current_schema)
-
- # Redshift does not support ROLLBACK TO SAVEPOINT
- def savepoint(self, id):
- pass
-
- def savepoint_release(self, id):
- pass
-
- def savepoint_rollback(self, id):
- self.rollback()
-
- # Redshift does not enforce primary and unique keys
- def _insert_lock_row(self, pid, timeout, poll_interval=0.5):
- poll_interval = min(poll_interval, timeout)
- started = time.time()
- while True:
- with self.transaction():
- # prevents isolation violation errors
- self.execute("LOCK {}".format(self.lock_table_quoted))
- cursor = self.execute(
- "SELECT pid FROM {}".format(self.lock_table_quoted)
- )
- row = cursor.fetchone()
- if not row:
- self.execute(
- "INSERT INTO {} (locked, ctime, pid) "
- "VALUES (1, :when, :pid)".format(self.lock_table_quoted),
- {"when": datetime.utcnow(), "pid": pid},
- )
- return
- elif timeout and time.time() > started + timeout:
- raise exceptions.LockTimeout(
- "Process {} has locked this database "
- "(run yoyo break-lock to remove this lock)".format(row[0])
- )
- else:
- time.sleep(poll_interval)
-
-
-class SnowflakeBackend(DatabaseBackend):
-
- driver_module = "snowflake.connector"
-
- def connect(self, dburi):
- database, schema = dburi.database.split("/")
- return self.driver.connect(
- user=dburi.username,
- password=dburi.password,
- account=dburi.hostname,
- database=database,
- schema=schema,
- **dburi.args,
- )
-
- def savepoint(self, id):
- pass
-
- def savepoint_release(self, id):
- pass
-
- def savepoint_rollback(self, id):
- pass
-
-
def get_backend_class(name):
backend_eps = entry_points(group="yoyo.backends")
return backend_eps[name].load()
diff --git a/yoyo/backends/contrib/__init__.py b/yoyo/backends/contrib/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/yoyo/backends/contrib/__init__.py
diff --git a/yoyo/backends/contrib/odbc.py b/yoyo/backends/contrib/odbc.py
new file mode 100644
index 0000000..05eb594
--- /dev/null
+++ b/yoyo/backends/contrib/odbc.py
@@ -0,0 +1,31 @@
+# Copyright 2015 Oliver Cope
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from yoyo.backends.base import DatabaseBackend
+
+
+class ODBCBackend(DatabaseBackend):
+ driver_module = "pyodbc"
+
+ def connect(self, dburi):
+ args = [
+ ("UID", dburi.username),
+ ("PWD", dburi.password),
+ ("ServerName", dburi.hostname),
+ ("Port", dburi.port),
+ ("Database", dburi.database),
+ ]
+ args.extend(dburi.args.items())
+ s = ";".join("{}={}".format(k, v) for k, v in args if v is not None)
+ return self.driver.connect(s)
diff --git a/yoyo/backends/contrib/oracle.py b/yoyo/backends/contrib/oracle.py
new file mode 100644
index 0000000..24f644c
--- /dev/null
+++ b/yoyo/backends/contrib/oracle.py
@@ -0,0 +1,46 @@
+# Copyright 2015 Oliver Cope
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from yoyo.backends.base import DatabaseBackend
+
+
+class OracleBackend(DatabaseBackend):
+
+ driver_module = "cx_Oracle"
+ list_tables_sql = "SELECT table_name FROM all_tables WHERE owner=user"
+
+ def begin(self):
+ """Oracle is always in a transaction, and has no "BEGIN" statement."""
+ self._in_transaction = True
+
+ def connect(self, dburi):
+ kwargs = dburi.args
+ if dburi.username is not None:
+ kwargs["user"] = dburi.username
+ if dburi.password is not None:
+ kwargs["password"] = dburi.password
+ # Oracle combines the hostname, port and database into a single DSN.
+ # The DSN can also be a "net service name"
+ kwargs["dsn"] = ""
+ if dburi.hostname is not None:
+ kwargs["dsn"] = dburi.hostname
+ if dburi.port is not None:
+ kwargs["dsn"] += ":{0}".format(dburi.port)
+ if dburi.database is not None:
+ if kwargs["dsn"]:
+ kwargs["dsn"] += "/{0}".format(dburi.database)
+ else:
+ kwargs["dsn"] = dburi.database
+
+ return self.driver.connect(**kwargs)
diff --git a/yoyo/backends/contrib/redshift.py b/yoyo/backends/contrib/redshift.py
new file mode 100644
index 0000000..80fdb5a
--- /dev/null
+++ b/yoyo/backends/contrib/redshift.py
@@ -0,0 +1,62 @@
+# Copyright 2015 Oliver Cope
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from datetime import datetime
+
+from yoyo import exceptions
+from yoyo.backends.core.postgresql import PostgresqlBackend
+
+
+class RedshiftBackend(PostgresqlBackend):
+ def list_tables(self):
+ current_schema = self.execute("SELECT current_schema()").fetchone()[0]
+ return super(PostgresqlBackend, self).list_tables(schema=current_schema)
+
+ # Redshift does not support ROLLBACK TO SAVEPOINT
+ def savepoint(self, id):
+ pass
+
+ def savepoint_release(self, id):
+ pass
+
+ def savepoint_rollback(self, id):
+ self.rollback()
+
+ # Redshift does not enforce primary and unique keys
+ def _insert_lock_row(self, pid, timeout, poll_interval=0.5):
+ poll_interval = min(poll_interval, timeout)
+ started = time.time()
+ while True:
+ with self.transaction():
+ # prevents isolation violation errors
+ self.execute("LOCK {}".format(self.lock_table_quoted))
+ cursor = self.execute(
+ "SELECT pid FROM {}".format(self.lock_table_quoted)
+ )
+ row = cursor.fetchone()
+ if not row:
+ self.execute(
+ "INSERT INTO {} (locked, ctime, pid) "
+ "VALUES (1, :when, :pid)".format(self.lock_table_quoted),
+ {"when": datetime.utcnow(), "pid": pid},
+ )
+ return
+ elif timeout and time.time() > started + timeout:
+ raise exceptions.LockTimeout(
+ "Process {} has locked this database "
+ "(run yoyo break-lock to remove this lock)".format(row[0])
+ )
+ else:
+ time.sleep(poll_interval)
diff --git a/yoyo/backends/contrib/snowflake.py b/yoyo/backends/contrib/snowflake.py
new file mode 100644
index 0000000..244a156
--- /dev/null
+++ b/yoyo/backends/contrib/snowflake.py
@@ -0,0 +1,40 @@
+# Copyright 2015 Oliver Cope
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from yoyo.backends.base import DatabaseBackend
+
+
+class SnowflakeBackend(DatabaseBackend):
+
+ driver_module = "snowflake.connector"
+
+ def connect(self, dburi):
+ database, schema = dburi.database.split("/")
+ return self.driver.connect(
+ user=dburi.username,
+ password=dburi.password,
+ account=dburi.hostname,
+ database=database,
+ schema=schema,
+ **dburi.args,
+ )
+
+ def savepoint(self, id):
+ pass
+
+ def savepoint_release(self, id):
+ pass
+
+ def savepoint_rollback(self, id):
+ pass
diff --git a/yoyo/backends/core/__init__.py b/yoyo/backends/core/__init__.py
new file mode 100644
index 0000000..183e24c
--- /dev/null
+++ b/yoyo/backends/core/__init__.py
@@ -0,0 +1,9 @@
+from yoyo.backends.core.mysql import MySQLBackend
+from yoyo.backends.core.sqlite3 import SQLiteBackend
+from yoyo.backends.core.postgresql import PostgresqlBackend
+
+__all__ = [
+ "MySQLBackend",
+ "SQLiteBackend",
+ "PostgresqlBackend",
+]
diff --git a/yoyo/backends/core/mysql.py b/yoyo/backends/core/mysql.py
new file mode 100644
index 0000000..e8efb97
--- /dev/null
+++ b/yoyo/backends/core/mysql.py
@@ -0,0 +1,68 @@
+# Copyright 2015 Oliver Cope
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from yoyo.backends.base import DatabaseBackend
+
+
+class MySQLBackend(DatabaseBackend):
+
+ driver_module = "pymysql"
+ list_tables_sql = (
+ "SELECT table_name FROM information_schema.tables "
+ "WHERE table_schema = :database"
+ )
+
+ def connect(self, dburi):
+ kwargs = {"db": dburi.database}
+ kwargs.update(dburi.args)
+ if dburi.username is not None:
+ kwargs["user"] = dburi.username
+ if dburi.password is not None:
+ kwargs["passwd"] = dburi.password
+ if dburi.hostname is not None:
+ kwargs["host"] = dburi.hostname
+ if dburi.port is not None:
+ kwargs["port"] = dburi.port
+ if "unix_socket" in dburi.args:
+ kwargs["unix_socket"] = dburi.args["unix_socket"]
+ if "ssl" in dburi.args:
+ kwargs["ssl"] = {}
+
+ if "sslca" in dburi.args:
+ kwargs["ssl"]["ca"] = dburi.args["sslca"]
+
+ if "sslcapath" in dburi.args:
+ kwargs["ssl"]["capath"] = dburi.args["sslcapath"]
+
+ if "sslcert" in dburi.args:
+ kwargs["ssl"]["cert"] = dburi.args["sslcert"]
+
+ if "sslkey" in dburi.args:
+ kwargs["ssl"]["key"] = dburi.args["sslkey"]
+
+ if "sslcipher" in dburi.args:
+ kwargs["ssl"]["cipher"] = dburi.args["sslcipher"]
+
+ kwargs["db"] = dburi.database
+ return self.driver.connect(**kwargs)
+
+ def quote_identifier(self, identifier):
+ sql_mode = self.execute("SHOW VARIABLES LIKE 'sql_mode'").fetchone()[1]
+ if "ansi_quotes" in sql_mode.lower():
+ return super(MySQLBackend, self).quote_identifier(identifier)
+ return "`{}`".format(identifier)
+
+
+class MySQLdbBackend(MySQLBackend):
+ driver_module = "MySQLdb"
diff --git a/yoyo/backends/core/postgresql.py b/yoyo/backends/core/postgresql.py
new file mode 100644
index 0000000..b5bf3c3
--- /dev/null
+++ b/yoyo/backends/core/postgresql.py
@@ -0,0 +1,58 @@
+# Copyright 2015 Oliver Cope
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from contextlib import contextmanager
+
+from yoyo.backends.base import DatabaseBackend
+
+
+class PostgresqlBackend(DatabaseBackend):
+
+ driver_module = "psycopg2"
+ schema = None
+ list_tables_sql = (
+ "SELECT table_name FROM information_schema.tables "
+ "WHERE table_schema = :schema"
+ )
+
+ def connect(self, dburi):
+ kwargs = {"dbname": dburi.database}
+ kwargs.update(dburi.args)
+ if dburi.username is not None:
+ kwargs["user"] = dburi.username
+ if dburi.password is not None:
+ kwargs["password"] = dburi.password
+ if dburi.port is not None:
+ kwargs["port"] = dburi.port
+ if dburi.hostname is not None:
+ kwargs["host"] = dburi.hostname
+ self.schema = kwargs.pop("schema", None)
+ return self.driver.connect(**kwargs)
+
+ @contextmanager
+ def disable_transactions(self):
+ with super(PostgresqlBackend, self).disable_transactions():
+ saved = self.connection.autocommit
+ self.connection.autocommit = True
+ yield
+ self.connection.autocommit = saved
+
+ def init_connection(self, connection):
+ if self.schema:
+ cursor = connection.cursor()
+ cursor.execute("SET search_path TO {}".format(self.schema))
+
+ def list_tables(self):
+ current_schema = self.execute("SELECT current_schema").fetchone()[0]
+ return super(PostgresqlBackend, self).list_tables(schema=current_schema)
diff --git a/yoyo/backends/core/sqlite3.py b/yoyo/backends/core/sqlite3.py
new file mode 100644
index 0000000..43345b0
--- /dev/null
+++ b/yoyo/backends/core/sqlite3.py
@@ -0,0 +1,32 @@
+# Copyright 2015 Oliver Cope
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from yoyo.backends.base import DatabaseBackend
+
+
+class SQLiteBackend(DatabaseBackend):
+
+ driver_module = "sqlite3"
+ list_tables_sql = "SELECT name FROM sqlite_master WHERE type = 'table'"
+
+ def connect(self, dburi):
+ # Ensure that multiple connections share the same data
+ # https://sqlite.org/sharedcache.html
+ conn = self.driver.connect(
+ f"file:{dburi.database}?cache=shared",
+ uri=True,
+ detect_types=self.driver.PARSE_DECLTYPES,
+ )
+ conn.isolation_level = None
+ return conn
diff --git a/yoyo/tests/conftest.py b/yoyo/tests/conftest.py
index 1a8ce51..1bf8d4a 100644
--- a/yoyo/tests/conftest.py
+++ b/yoyo/tests/conftest.py
@@ -1,6 +1,6 @@
import pytest
-from yoyo import backends
+import yoyo.backends.core
from yoyo.connections import get_backend
from yoyo.tests import dburi_sqlite3
from yoyo.tests import get_test_backends
@@ -13,7 +13,7 @@ def _backend(dburi):
"""
backend = get_backend(dburi)
with backend.transaction():
- if backend.__class__ is backends.MySQLBackend:
+ if backend.__class__ is yoyo.backends.core.MySQLBackend:
backend.execute(
"CREATE TABLE yoyo_t (id CHAR(1) primary key) ENGINE=InnoDB"
)
diff --git a/yoyo/tests/test_backends.py b/yoyo/tests/test_backends.py
index c812f3d..8d99009 100644
--- a/yoyo/tests/test_backends.py
+++ b/yoyo/tests/test_backends.py
@@ -10,6 +10,7 @@ import pytest
from yoyo import backends
from yoyo import read_migrations
from yoyo import exceptions
+from yoyo.backends.contrib.redshift import RedshiftBackend
from yoyo.connections import get_backend
from yoyo.tests import get_test_backends
from yoyo.tests import get_test_dburis
@@ -74,8 +75,8 @@ class TestTransactionHandling(object):
def test_backend_detects_transactional_ddl(self, backend):
expected = {
+ RedshiftBackend: True,
backends.PostgresqlBackend: True,
- backends.RedshiftBackend: True,
backends.SQLiteBackend: True,
backends.MySQLBackend: False,
}
diff --git a/yoyo/tests/test_connections.py b/yoyo/tests/test_connections.py
index ec4b70b..10b32c6 100644
--- a/yoyo/tests/test_connections.py
+++ b/yoyo/tests/test_connections.py
@@ -18,6 +18,8 @@ from unittest.mock import patch, call, MagicMock
import pytest
from yoyo.connections import parse_uri, BadConnectionURI
+from yoyo import backends
+from yoyo.backends.contrib import odbc
class MockDatabaseError(Exception):
@@ -73,17 +75,15 @@ class TestParseURI:
@patch(
- "yoyo.backends.get_dbapi_module",
+ "yoyo.backends.base.get_dbapi_module",
return_value=MagicMock(DatabaseError=MockDatabaseError, paramstyle="qmark"),
)
def test_connections(get_dbapi_module):
- from yoyo import backends
-
u = parse_uri("odbc://scott:tiger@db.example.org:42/northwind?foo=bar")
cases = [
(
- backends.ODBCBackend,
+ odbc.ODBCBackend,
"pyodbc",
call(
"UID=scott;PWD=tiger;ServerName=db.example.org;"