diff options
Diffstat (limited to 'synapse/storage/engines')
-rw-r--r-- | synapse/storage/engines/__init__.py | 28 | ||||
-rw-r--r-- | synapse/storage/engines/_base.py | 87 | ||||
-rw-r--r-- | synapse/storage/engines/postgres.py | 108 | ||||
-rw-r--r-- | synapse/storage/engines/sqlite.py | 55 |
4 files changed, 237 insertions, 41 deletions
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py index 9d2d519922..035f9ea6e9 100644 --- a/synapse/storage/engines/__init__.py +++ b/synapse/storage/engines/__init__.py @@ -12,29 +12,31 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -import importlib import platform -from ._base import IncorrectDatabaseSetup +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup from .postgres import PostgresEngine from .sqlite import Sqlite3Engine -SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine} - -def create_engine(database_config): +def create_engine(database_config) -> BaseDatabaseEngine: name = database_config["name"] - engine_class = SUPPORTED_MODULE.get(name, None) - if engine_class: + if name == "sqlite3": + import sqlite3 + + return Sqlite3Engine(sqlite3, database_config) + + if name == "psycopg2": # pypy requires psycopg2cffi rather than psycopg2 - if name == "psycopg2" and platform.python_implementation() == "PyPy": - name = "psycopg2cffi" - module = importlib.import_module(name) - return engine_class(module, database_config) + if platform.python_implementation() == "PyPy": + import psycopg2cffi as psycopg2 # type: ignore + else: + import psycopg2 # type: ignore + + return PostgresEngine(psycopg2, database_config) raise RuntimeError("Unsupported database engine '%s'" % (name,)) -__all__ = ["create_engine", "IncorrectDatabaseSetup"] +__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"] diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py index ec5a4d198b..ab0bbe4bd3 100644 --- a/synapse/storage/engines/_base.py +++ b/synapse/storage/engines/_base.py @@ -12,7 +12,94 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import abc +from typing import Generic, TypeVar + +from synapse.storage.types import Connection class IncorrectDatabaseSetup(RuntimeError): pass + + +ConnectionType = TypeVar("ConnectionType", bound=Connection) + + +class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta): + def __init__(self, module, database_config: dict): + self.module = module + + @property + @abc.abstractmethod + def single_threaded(self) -> bool: + ... + + @property + @abc.abstractmethod + def can_native_upsert(self) -> bool: + """ + Do we support native UPSERTs? + """ + ... + + @property + @abc.abstractmethod + def supports_tuple_comparison(self) -> bool: + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? + """ + ... + + @property + @abc.abstractmethod + def supports_using_any_list(self) -> bool: + """ + Do we support using `a = ANY(?)` and passing a list + """ + ... + + @abc.abstractmethod + def check_database( + self, db_conn: ConnectionType, allow_outdated_version: bool = False + ) -> None: + ... + + @abc.abstractmethod + def check_new_database(self, txn) -> None: + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ + ... + + @abc.abstractmethod + def convert_param_style(self, sql: str) -> str: + ... + + @abc.abstractmethod + def on_new_connection(self, db_conn: ConnectionType) -> None: + ... + + @abc.abstractmethod + def is_deadlock(self, error: Exception) -> bool: + ... + + @abc.abstractmethod + def is_connection_closed(self, conn: ConnectionType) -> bool: + ... + + @abc.abstractmethod + def lock_table(self, txn, table: str) -> None: + ... + + @abc.abstractmethod + def get_next_state_group_id(self, txn) -> int: + """Returns an int that can be used as a new state_group ID + """ + ... + + @property + @abc.abstractmethod + def server_version(self) -> str: + """Gets a string giving the server version. For example: '3.22.0' + """ + ... diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py index 289b6bc281..6c7d08a6f2 100644 --- a/synapse/storage/engines/postgres.py +++ b/synapse/storage/engines/postgres.py @@ -13,32 +13,32 @@ # See the License for the specific language governing permissions and # limitations under the License. -from ._base import IncorrectDatabaseSetup +import logging +from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup -class PostgresEngine(object): - single_threaded = False +logger = logging.getLogger(__name__) + +class PostgresEngine(BaseDatabaseEngine): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) self.module.extensions.register_type(self.module.extensions.UNICODE) - self.synchronous_commit = database_config.get("synchronous_commit", True) - self._version = None # unknown as yet - def check_database(self, txn): - txn.execute("SHOW SERVER_ENCODING") - rows = txn.fetchall() - if rows and rows[0][0] != "UTF8": - raise IncorrectDatabaseSetup( - "Database has incorrect encoding: '%s' instead of 'UTF8'\n" - "See docs/postgres.rst for more information." % (rows[0][0],) - ) + # Disables passing `bytes` to txn.execute, c.f. #6186. If you do + # actually want to use bytes than wrap it in `bytearray`. + def _disable_bytes_adapter(_): + raise Exception("Passing bytes to DB is disabled.") - def convert_param_style(self, sql): - return sql.replace("?", "%s") + self.module.extensions.register_adapter(bytes, _disable_bytes_adapter) + self.synchronous_commit = database_config.get("synchronous_commit", True) + self._version = None # unknown as yet - def on_new_connection(self, db_conn): + @property + def single_threaded(self) -> bool: + return False + def check_database(self, db_conn, allow_outdated_version: bool = False): # Get the version of PostgreSQL that we're using. As per the psycopg2 # docs: The number is formed by converting the major, minor, and # revision numbers into two-decimal-digit numbers and appending them @@ -46,9 +46,64 @@ class PostgresEngine(object): self._version = db_conn.server_version # Are we on a supported PostgreSQL version? - if self._version < 90500: + if not allow_outdated_version and self._version < 90500: raise RuntimeError("Synapse requires PostgreSQL 9.5+ or above.") + with db_conn.cursor() as txn: + txn.execute("SHOW SERVER_ENCODING") + rows = txn.fetchall() + if rows and rows[0][0] != "UTF8": + raise IncorrectDatabaseSetup( + "Database has incorrect encoding: '%s' instead of 'UTF8'\n" + "See docs/postgres.md for more information." % (rows[0][0],) + ) + + txn.execute( + "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" + ) + collation, ctype = txn.fetchone() + if collation != "C": + logger.warning( + "Database has incorrect collation of %r. Should be 'C'\n" + "See docs/postgres.md for more information.", + collation, + ) + + if ctype != "C": + logger.warning( + "Database has incorrect ctype of %r. Should be 'C'\n" + "See docs/postgres.md for more information.", + ctype, + ) + + def check_new_database(self, txn): + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ + + txn.execute( + "SELECT datcollate, datctype FROM pg_database WHERE datname = current_database()" + ) + collation, ctype = txn.fetchone() + + errors = [] + + if collation != "C": + errors.append(" - 'COLLATE' is set to %r. Should be 'C'" % (collation,)) + + if ctype != "C": + errors.append(" - 'CTYPE' is set to %r. Should be 'C'" % (collation,)) + + if errors: + raise IncorrectDatabaseSetup( + "Database is incorrectly configured:\n\n%s\n\n" + "See docs/postgres.md for more information." % ("\n".join(errors)) + ) + + def convert_param_style(self, sql): + return sql.replace("?", "%s") + + def on_new_connection(self, db_conn): db_conn.set_isolation_level( self.module.extensions.ISOLATION_LEVEL_REPEATABLE_READ ) @@ -72,6 +127,19 @@ class PostgresEngine(object): """ return True + @property + def supports_tuple_comparison(self): + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? + """ + return True + + @property + def supports_using_any_list(self): + """Do we support using `a = ANY(?)` and passing a list + """ + return True + def is_deadlock(self, error): if isinstance(error, self.module.DatabaseError): # https://www.postgresql.org/docs/current/static/errcodes-appendix.html @@ -99,8 +167,8 @@ class PostgresEngine(object): Returns: string """ - # note that this is a bit of a hack because it relies on on_new_connection - # having been called at least once. Still, that should be a safe bet here. + # note that this is a bit of a hack because it relies on check_database + # having been called. Still, that should be a safe bet here. numver = self._version assert numver is not None diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py index e9b9caa49a..215a949442 100644 --- a/synapse/storage/engines/sqlite.py +++ b/synapse/storage/engines/sqlite.py @@ -12,18 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import struct import threading +import typing -from synapse.storage.prepare_database import prepare_database +from synapse.storage.engines import BaseDatabaseEngine +if typing.TYPE_CHECKING: + import sqlite3 # noqa: F401 -class Sqlite3Engine(object): - single_threaded = True +class Sqlite3Engine(BaseDatabaseEngine["sqlite3.Connection"]): def __init__(self, database_module, database_config): - self.module = database_module + super().__init__(database_module, database_config) + + database = database_config.get("args", {}).get("database") + self._is_in_memory = database in (None, ":memory:",) # The current max state_group, or None if we haven't looked # in the DB yet. @@ -31,6 +35,10 @@ class Sqlite3Engine(object): self._current_state_group_id_lock = threading.Lock() @property + def single_threaded(self) -> bool: + return True + + @property def can_native_upsert(self): """ Do we support native UPSERTs? This requires SQLite3 3.24+, plus some @@ -38,15 +46,46 @@ class Sqlite3Engine(object): """ return self.module.sqlite_version_info >= (3, 24, 0) - def check_database(self, txn): - pass + @property + def supports_tuple_comparison(self): + """ + Do we support comparing tuples, i.e. `(a, b) > (c, d)`? This requires + SQLite 3.15+. + """ + return self.module.sqlite_version_info >= (3, 15, 0) + + @property + def supports_using_any_list(self): + """Do we support using `a = ANY(?)` and passing a list + """ + return False + + def check_database(self, db_conn, allow_outdated_version: bool = False): + if not allow_outdated_version: + version = self.module.sqlite_version_info + if version < (3, 11, 0): + raise RuntimeError("Synapse requires sqlite 3.11 or above.") + + def check_new_database(self, txn): + """Gets called when setting up a brand new database. This allows us to + apply stricter checks on new databases versus existing database. + """ def convert_param_style(self, sql): return sql def on_new_connection(self, db_conn): - prepare_database(db_conn, self, config=None) + # We need to import here to avoid an import loop. + from synapse.storage.prepare_database import prepare_database + + if self._is_in_memory: + # In memory databases need to be rebuilt each time. Ideally we'd + # reuse the same connection as we do when starting up, but that + # would involve using adbapi before we have started the reactor. + prepare_database(db_conn, self, config=None) + db_conn.create_function("rank", 1, _rank) + db_conn.execute("PRAGMA foreign_keys = ON;") def is_deadlock(self, error): return False |