summary refs log tree commit diff
path: root/synapse/storage/engines
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/engines')
-rw-r--r--synapse/storage/engines/__init__.py28
-rw-r--r--synapse/storage/engines/_base.py87
-rw-r--r--synapse/storage/engines/postgres.py22
-rw-r--r--synapse/storage/engines/sqlite.py13
4 files changed, 124 insertions, 26 deletions
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import importlib
 import platform
 
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 from .postgres import PostgresEngine
 from .sqlite import Sqlite3Engine
 
-SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
-
 
-def create_engine(database_config):
+def create_engine(database_config) -> BaseDatabaseEngine:
     name = database_config["name"]
-    engine_class = SUPPORTED_MODULE.get(name, None)
 
-    if engine_class:
+    if name == "sqlite3":
+        import sqlite3
+
+        return Sqlite3Engine(sqlite3, database_config)
+
+    if name == "psycopg2":
         # pypy requires psycopg2cffi rather than psycopg2
-        if name == "psycopg2" and platform.python_implementation() == "PyPy":
-            name = "psycopg2cffi"
-        module = importlib.import_module(name)
-        return engine_class(module, database_config)
+        if platform.python_implementation() == "PyPy":
+            import psycopg2cffi as psycopg2  # type: ignore
+        else:
+            import psycopg2  # type: ignore
+
+        return PostgresEngine(psycopg2, database_config)
 
     raise RuntimeError("Unsupported database engine '%s'" % (name,))
 
 
-__all__ = ["create_engine", "IncorrectDatabaseSetup"]
+__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
+from typing import Generic, TypeVar
+
+from synapse.storage.types import Connection
 
 
 class IncorrectDatabaseSetup(RuntimeError):
     pass
+
+
+ConnectionType = TypeVar("ConnectionType", bound=Connection)
+
+
+class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+    def __init__(self, module, database_config: dict):
+        self.module = module
+
+    @property
+    @abc.abstractmethod
+    def single_threaded(self) -> bool:
+        ...
+
+    @property
+    @abc.abstractmethod
+    def can_native_upsert(self) -> bool:
+        """
+        Do we support native UPSERTs?
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def supports_tuple_comparison(self) -> bool:
+        """
+        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def supports_using_any_list(self) -> bool:
+        """
+        Do we support using `a = ANY(?)` and passing a list
+        """
+        ...
+
+    @abc.abstractmethod
+    def check_database(
+        self, db_conn: ConnectionType, allow_outdated_version: bool = False
+    ) -> None:
+        ...
+
+    @abc.abstractmethod
+    def check_new_database(self, txn) -> None:
+        """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing database.
+        """
+        ...
+
+    @abc.abstractmethod
+    def convert_param_style(self, sql: str) -> str:
+        ...
+
+    @abc.abstractmethod
+    def on_new_connection(self, db_conn: ConnectionType) -> None:
+        ...
+
+    @abc.abstractmethod
+    def is_deadlock(self, error: Exception) -> bool:
+        ...
+
+    @abc.abstractmethod
+    def is_connection_closed(self, conn: ConnectionType) -> bool:
+        ...
+
+    @abc.abstractmethod
+    def lock_table(self, txn, table: str) -> None:
+        ...
+
+    @abc.abstractmethod
+    def get_next_state_group_id(self, txn) -> int:
+        """Returns an int that can be used as a new state_group ID
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def server_version(self) -> str:
+        """Gets a string giving the server version. For example: '3.22.0'
+        """
+        ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index a077345960..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -15,16 +15,14 @@
 
 import logging
 
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 
 logger = logging.getLogger(__name__)
 
 
-class PostgresEngine(object):
-    single_threaded = False
-
+class PostgresEngine(BaseDatabaseEngine):
     def __init__(self, database_module, database_config):
-        self.module = database_module
+        super().__init__(database_module, database_config)
         self.module.extensions.register_type(self.module.extensions.UNICODE)
 
         # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
@@ -36,6 +34,10 @@ class PostgresEngine(object):
         self.synchronous_commit = database_config.get("synchronous_commit", True)
         self._version = None  # unknown as yet
 
+    @property
+    def single_threaded(self) -> bool:
+        return False
+
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
@@ -53,7 +55,7 @@ class PostgresEngine(object):
             if rows and rows[0][0] != "UTF8":
                 raise IncorrectDatabaseSetup(
                     "Database has incorrect encoding: '%s' instead of 'UTF8'\n"
-                    "See docs/postgres.rst for more information." % (rows[0][0],)
+                    "See docs/postgres.md for more information." % (rows[0][0],)
                 )
 
             txn.execute(
@@ -62,12 +64,16 @@ class PostgresEngine(object):
             collation, ctype = txn.fetchone()
             if collation != "C":
                 logger.warning(
-                    "Database has incorrect collation of %r. Should be 'C'", collation
+                    "Database has incorrect collation of %r. Should be 'C'\n"
+                    "See docs/postgres.md for more information.",
+                    collation,
                 )
 
             if ctype != "C":
                 logger.warning(
-                    "Database has incorrect ctype of %r. Should be 'C'", ctype
+                    "Database has incorrect ctype of %r. Should be 'C'\n"
+                    "See docs/postgres.md for more information.",
+                    ctype,
                 )
 
     def check_new_database(self, txn):
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 641e490697..2bfeefd54e 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,16 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import sqlite3
 import struct
 import threading
 
+from synapse.storage.engines import BaseDatabaseEngine
 
-class Sqlite3Engine(object):
-    single_threaded = True
 
+class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
     def __init__(self, database_module, database_config):
-        self.module = database_module
+        super().__init__(database_module, database_config)
 
         database = database_config.get("args", {}).get("database")
         self._is_in_memory = database in (None, ":memory:",)
@@ -32,6 +32,10 @@ class Sqlite3Engine(object):
         self._current_state_group_id_lock = threading.Lock()
 
     @property
+    def single_threaded(self) -> bool:
+        return True
+
+    @property
     def can_native_upsert(self):
         """
         Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
@@ -68,7 +72,6 @@ class Sqlite3Engine(object):
         return sql
 
     def on_new_connection(self, db_conn):
-
         # We need to import here to avoid an import loop.
         from synapse.storage.prepare_database import prepare_database