diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
-import importlib
import platform
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
from .postgres import PostgresEngine
from .sqlite import Sqlite3Engine
-SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
-
-def create_engine(database_config):
+def create_engine(database_config) -> BaseDatabaseEngine:
name = database_config["name"]
- engine_class = SUPPORTED_MODULE.get(name, None)
- if engine_class:
+ if name == "sqlite3":
+ import sqlite3
+
+ return Sqlite3Engine(sqlite3, database_config)
+
+ if name == "psycopg2":
# pypy requires psycopg2cffi rather than psycopg2
- if name == "psycopg2" and platform.python_implementation() == "PyPy":
- name = "psycopg2cffi"
- module = importlib.import_module(name)
- return engine_class(module, database_config)
+ if platform.python_implementation() == "PyPy":
+ import psycopg2cffi as psycopg2 # type: ignore
+ else:
+ import psycopg2 # type: ignore
+
+ return PostgresEngine(psycopg2, database_config)
raise RuntimeError("Unsupported database engine '%s'" % (name,))
-__all__ = ["create_engine", "IncorrectDatabaseSetup"]
+__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import abc
+from typing import Generic, TypeVar
+
+from synapse.storage.types import Connection
class IncorrectDatabaseSetup(RuntimeError):
pass
+
+
+ConnectionType = TypeVar("ConnectionType", bound=Connection)
+
+
+class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+ def __init__(self, module, database_config: dict):
+ self.module = module  # the DB driver module; database_config is not used here
+
+ @property
+ @abc.abstractmethod
+ def single_threaded(self) -> bool:
+ ...  # the SQLite engine returns True, the PostgreSQL engine returns False
+
+ @property
+ @abc.abstractmethod
+ def can_native_upsert(self) -> bool:
+ """
+ Do we support native UPSERTs?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_tuple_comparison(self) -> bool:
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def supports_using_any_list(self) -> bool:
+ """
+ Do we support using `a = ANY(?)` and passing a list?
+ """
+ ...
+
+ @abc.abstractmethod
+ def check_database(
+ self, db_conn: ConnectionType, allow_outdated_version: bool = False
+ ) -> None:
+ ...  # implementations raise if the database/server is unsuitable
+
+ @abc.abstractmethod
+ def check_new_database(self, txn) -> None:
+ """Gets called when setting up a brand new database. This allows us to
+ apply stricter checks on new databases versus existing databases.
+ """
+ ...
+
+ @abc.abstractmethod
+ def convert_param_style(self, sql: str) -> str:
+ ...  # PostgreSQL rewrites "?" placeholders to "%s"; SQLite returns sql unchanged
+
+ @abc.abstractmethod
+ def on_new_connection(self, db_conn: ConnectionType) -> None:
+ ...
+
+ @abc.abstractmethod
+ def is_deadlock(self, error: Exception) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def is_connection_closed(self, conn: ConnectionType) -> bool:
+ ...
+
+ @abc.abstractmethod
+ def lock_table(self, txn, table: str) -> None:
+ ...
+
+ @abc.abstractmethod
+ def get_next_state_group_id(self, txn) -> int:
+ """Returns an int that can be used as a new state_group ID
+ """
+ ...
+
+ @property
+ @abc.abstractmethod
+ def server_version(self) -> str:
+ """Gets a string giving the server version. For example: '3.22.0'
+ """
+ ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 1b97ee74e3..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -13,38 +13,97 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from ._base import IncorrectDatabaseSetup
+import logging
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
-class PostgresEngine(object):
- single_threaded = False
+logger = logging.getLogger(__name__)
+
+class PostgresEngine(BaseDatabaseEngine):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
self.module.extensions.register_type(self.module.extensions.UNICODE)
- self.synchronous_commit = database_config.get("synchronous_commit", True)
- self._version = None # unknown as yet
- def check_database(self, txn):
- txn.execute("SHOW SERVER_ENCODING")
- rows = txn.fetchall()
- if rows and rows[0][0] != "UTF8":
- raise IncorrectDatabaseSetup(
- "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
- "See docs/postgres.rst for more information." % (rows[0][0],)
- )
+ # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
+ # actually want to use bytes then wrap it in `bytearray`.
+ def _disable_bytes_adapter(_):
+ raise Exception("Passing bytes to DB is disabled.")
- def convert_param_style(self, sql):
- return sql.replace("?", "%s")
+ self.module.extensions.register_adapter(bytes, _disable_bytes_adapter)
+ self.synchronous_commit = database_config.get("synchronous_commit", True)
+ self._version = None # unknown as yet
- def on_new_connection(self, db_conn):
+ @property
+ def single_threaded(self) -> bool:
+ return False
+ def check_database(self, db_conn, allow_outdated_version: bool = False):
# Get the version of PostgreSQL that we're using. As per the psycopg2
# docs: The number is formed by converting the major, minor, and
# revision numbers into two-decimal-digit numbers and appending them
# together. For example, version 8.1.5 will be returned as 80105
self._version = db_conn.server_version
+ # Are we on a supported PostgreSQL version?
+ if not allow_outdated_version and self._version < 90500:
+ raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.")
+
+ with db_conn.cursor() as txn:
+ txn.execute("SHOW SERVER_ENCODING")
+ rows = txn.fetchall()
+ if rows and rows[0][0] != "UTF8":
+ raise IncorrectDatabaseSetup(
+ "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
+ "See docs/postgres.md for more information." % (rows[0][0],)
+ )
+
+ txn.execute(
+ "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+ )
+ collation, ctype = txn.fetchone()
+ if collation != "C":
+ logger.warning(
+ "Database has incorrect collation of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information.",
+ collation,
+ )
+
+ if ctype != "C":
+ logger.warning(
+ "Database has incorrect ctype of %r. Should be 'C'\n"
+ "See docs/postgres.md for more information.",
+ ctype,
+ )
+
+ def check_new_database(self, txn):
+ """Gets called when setting up a brand new database, via the cursor `txn`.
+ Raises IncorrectDatabaseSetup unless its COLLATE and CTYPE are both 'C'.
+ """
+
+ txn.execute(
+ "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()"
+ )
+ collation, ctype = txn.fetchone()
+
+ errors = []  # collect every problem so they are all reported in one exception
+
+ if collation != "C":
+ errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,))
+
+ if ctype != "C":
+ errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (ctype,))
+
+ if errors:
+ raise IncorrectDatabaseSetup(
+ "Database is incorrectly configured:\n\n%s\n\n"
+ "See docs/postgres.md for more information." % ("\n".join(errors),)
+ )
+
+ def convert_param_style(self, sql):
+ return sql.replace("?", "%s")
+
+ def on_new_connection(self, db_conn):
db_conn.set_isolation_level(
self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ
)
@@ -64,9 +123,22 @@ class PostgresEngine(object):
@property
def can_native_upsert(self):
"""
- Can we use native UPSERTs? This requires PostgreSQL 9.5+.
+ Can we use native UPSERTs?
+ """
+ return True
+
+ @property
+ def supports_tuple_comparison(self):
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+ """
+ return True
+
+ @property
+ def supports_using_any_list(self):
+ """Do we support using `a = ANY(?)` and passing a list
"""
- return self._version >= 90500
+ return True
def is_deadlock(self, error):
if isinstance(error, self.module.DatabaseError):
@@ -95,8 +167,8 @@ class PostgresEngine(object):
Returns:
string
"""
- # note that this is a bit of a hack because it relies on on_new_connection
- # having been called at least once. Still, that should be a safe bet here.
+ # note that this is a bit of a hack because it relies on check_database
+ # having been called. Still, that should be a safe bet here.
numver = self._version
assert numver is not None
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 933bcf42c2..3bc2e8b986 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,18 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-
import struct
import threading
+import typing
-from synapse.storage.prepare_database import prepare_database
+from synapse.storage.engines import BaseDatabaseEngine
+if typing.TYPE_CHECKING:
+ import sqlite3 # noqa: F401
-class Sqlite3Engine(object):
- single_threaded = True
+class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]):
def __init__(self, database_module, database_config):
- self.module = database_module
+ super().__init__(database_module, database_config)
+
+ database = database_config.get("args", {}).get("database")
+ self._is_in_memory = database in (None, ":memory:",)
# The current max state_group, or None if we haven't looked
# in the DB yet.
@@ -31,6 +35,10 @@ class Sqlite3Engine(object):
self._current_state_group_id_lock = threading.Lock()
@property
+ def single_threaded(self) -> bool:
+ return True
+
+ @property
def can_native_upsert(self):
"""
Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
@@ -38,14 +46,44 @@ class Sqlite3Engine(object):
"""
return self.module.sqlite_version_info >= (3, 24, 0)
- def check_database(self, txn):
- pass
+ @property
+ def supports_tuple_comparison(self):
+ """
+ Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires
+ SQLite 3.15+.
+ """
+ return self.module.sqlite_version_info >= (3, 15, 0)
+
+ @property
+ def supports_using_any_list(self):
+ """Do we support using `a = ANY(?)` and passing a list
+ """
+ return False
+
+ def check_database(self, db_conn, allow_outdated_version: bool = False):  # db_conn is not used here
+ if not allow_outdated_version:
+ version = self.module.sqlite_version_info  # version tuple exposed by the driver module
+ if version < (3, 11, 0):
+ raise RuntimeError("Synapse requires sqlite 3.11 or above.")
+
+ def check_new_database(self, txn):
+ """Gets called when setting up a brand new database. SQLite applies no
+ stricter checks to new databases, so this implementation does nothing.
+ """
def convert_param_style(self, sql):
return sql
def on_new_connection(self, db_conn):
- prepare_database(db_conn, self, config=None)
+ # We need to import here to avoid an import loop.
+ from synapse.storage.prepare_database import prepare_database
+
+ if self._is_in_memory:
+ # In memory databases need to be rebuilt each time. Ideally we'd
+ # reuse the same connection as we do when starting up, but that
+ # would involve using adbapi before we have started the reactor.
+ prepare_database(db_conn, self, config=None)
+
db_conn.create_function("rank", 1, _rank)
def is_deadlock(self, error):
@@ -85,7 +123,7 @@ class Sqlite3Engine(object):
def _parse_match_info(buf):
bufsize = len(buf)
- return [struct.unpack('@I', buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
+ return [struct.unpack("@I", buf[i : i + 4])[0] for i in range(0, bufsize, 4)]
def _rank(raw_match_info):
|