summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/6987.misc1
-rw-r--r--synapse/storage/database.py143
-rw-r--r--synapse/storage/engines/__init__.py28
-rw-r--r--synapse/storage/engines/_base.py87
-rw-r--r--synapse/storage/engines/postgres.py12
-rw-r--r--synapse/storage/engines/sqlite.py13
-rw-r--r--synapse/storage/types.py65
-rw-r--r--tox.ini5
8 files changed, 270 insertions, 84 deletions
diff --git a/changelog.d/6987.misc b/changelog.d/6987.misc
new file mode 100644
index 0000000000..7ff74cda55
--- /dev/null
+++ b/changelog.d/6987.misc
@@ -0,0 +1 @@
+Add some type annotations to the database storage classes.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 1953614401..609db40616 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -15,9 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import sys
 import time
-from typing import Iterable, Tuple
+from time import monotonic as monotonic_time
+from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple
 
 from six import iteritems, iterkeys, itervalues
 from six.moves import intern, range
@@ -32,24 +32,14 @@ from synapse.config.database import DatabaseConnectionConfig
 from synapse.logging.context import LoggingContext, make_deferred_yieldable
 from synapse.metrics.background_process_metrics import run_as_background_process
 from synapse.storage.background_updates import BackgroundUpdater
-from synapse.storage.engines import PostgresEngine, Sqlite3Engine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3Engine
+from synapse.storage.types import Connection, Cursor
 from synapse.util.stringutils import exception_to_unicode
 
-# import a function which will return a monotonic time, in seconds
-try:
-    # on python 3, use time.monotonic, since time.clock can go backwards
-    from time import monotonic as monotonic_time
-except ImportError:
-    # ... but python 2 doesn't have it
-    from time import clock as monotonic_time
-
 logger = logging.getLogger(__name__)
 
-try:
-    MAX_TXN_ID = sys.maxint - 1
-except AttributeError:
-    # python 3 does not have a maximum int value
-    MAX_TXN_ID = 2 ** 63 - 1
+# python 3 does not have a maximum int value
+MAX_TXN_ID = 2 ** 63 - 1
 
 sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")
@@ -77,7 +67,7 @@ UNIQUE_INDEX_BACKGROUND_UPDATES = {
 
 
 def make_pool(
-    reactor, db_config: DatabaseConnectionConfig, engine
+    reactor, db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
 ) -> adbapi.ConnectionPool:
     """Get the connection pool for the database.
     """
@@ -90,7 +80,9 @@ def make_pool(
     )
 
 
-def make_conn(db_config: DatabaseConnectionConfig, engine):
+def make_conn(
+    db_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+) -> Connection:
     """Make a new connection to the database and return it.
 
     Returns:
@@ -107,20 +99,27 @@ def make_conn(db_config: DatabaseConnectionConfig, engine):
     return db_conn
 
 
-class LoggingTransaction(object):
+# The type of entry which goes on our after_callbacks and exception_callbacks lists.
+#
+# Python 3.5.2 doesn't support Callable with an ellipsis, so we wrap it in quotes so
+# that mypy sees the type but the runtime python doesn't.
+_CallbackListEntry = Tuple["Callable[..., None]", Iterable[Any], Dict[str, Any]]
+
+
+class LoggingTransaction:
     """An object that almost-transparently proxies for the 'txn' object
     passed to the constructor. Adds logging and metrics to the .execute()
     method.
 
     Args:
         txn: The database transcation object to wrap.
-        name (str): The name of this transactions for logging.
-        database_engine (Sqlite3Engine|PostgresEngine)
-        after_callbacks(list|None): A list that callbacks will be appended to
+        name: The name of this transactions for logging.
+        database_engine
+        after_callbacks: A list that callbacks will be appended to
             that have been added by `call_after` which should be run on
             successful completion of the transaction. None indicates that no
             callbacks should be allowed to be scheduled to run.
-        exception_callbacks(list|None): A list that callbacks will be appended
+        exception_callbacks: A list that callbacks will be appended
             to that have been added by `call_on_exception` which should be run
             if transaction ends with an error. None indicates that no callbacks
             should be allowed to be scheduled to run.
@@ -135,46 +134,67 @@ class LoggingTransaction(object):
     ]
 
     def __init__(
-        self, txn, name, database_engine, after_callbacks=None, exception_callbacks=None
+        self,
+        txn: Cursor,
+        name: str,
+        database_engine: BaseDatabaseEngine,
+        after_callbacks: Optional[List[_CallbackListEntry]] = None,
+        exception_callbacks: Optional[List[_CallbackListEntry]] = None,
     ):
-        object.__setattr__(self, "txn", txn)
-        object.__setattr__(self, "name", name)
-        object.__setattr__(self, "database_engine", database_engine)
-        object.__setattr__(self, "after_callbacks", after_callbacks)
-        object.__setattr__(self, "exception_callbacks", exception_callbacks)
+        self.txn = txn
+        self.name = name
+        self.database_engine = database_engine
+        self.after_callbacks = after_callbacks
+        self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback, *args, **kwargs):
+    def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
         """
+        # if self.after_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.after_callbacks is not None
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_on_exception(self, callback, *args, **kwargs):
+    def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+        # if self.exception_callbacks is None, that means that whatever constructed the
+        # LoggingTransaction isn't expecting there to be any callbacks; assert that
+        # is not the case.
+        assert self.exception_callbacks is not None
         self.exception_callbacks.append((callback, args, kwargs))
 
-    def __getattr__(self, name):
-        return getattr(self.txn, name)
+    def fetchall(self) -> List[Tuple]:
+        return self.txn.fetchall()
 
-    def __setattr__(self, name, value):
-        setattr(self.txn, name, value)
+    def fetchone(self) -> Tuple:
+        return self.txn.fetchone()
 
-    def __iter__(self):
+    def __iter__(self) -> Iterator[Tuple]:
         return self.txn.__iter__()
 
+    @property
+    def rowcount(self) -> int:
+        return self.txn.rowcount
+
+    @property
+    def description(self) -> Any:
+        return self.txn.description
+
     def execute_batch(self, sql, args):
         if isinstance(self.database_engine, PostgresEngine):
-            from psycopg2.extras import execute_batch
+            from psycopg2.extras import execute_batch  # type: ignore
 
             self._do_execute(lambda *x: execute_batch(self.txn, *x), sql, args)
         else:
             for val in args:
                 self.execute(sql, val)
 
-    def execute(self, sql, *args):
+    def execute(self, sql: str, *args: Any):
         self._do_execute(self.txn.execute, sql, *args)
 
-    def executemany(self, sql, *args):
+    def executemany(self, sql: str, *args: Any):
         self._do_execute(self.txn.executemany, sql, *args)
 
     def _make_sql_one_line(self, sql):
@@ -207,6 +227,9 @@ class LoggingTransaction(object):
             sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
             sql_query_timer.labels(sql.split()[0]).observe(secs)
 
+    def close(self):
+        self.txn.close()
+
 
 class PerformanceCounters(object):
     def __init__(self):
@@ -251,7 +274,9 @@ class Database(object):
 
     _TXN_ID = 0
 
-    def __init__(self, hs, database_config: DatabaseConnectionConfig, engine):
+    def __init__(
+        self, hs, database_config: DatabaseConnectionConfig, engine: BaseDatabaseEngine
+    ):
         self.hs = hs
         self._clock = hs.get_clock()
         self._database_config = database_config
@@ -259,9 +284,9 @@ class Database(object):
 
         self.updates = BackgroundUpdater(hs, self)
 
-        self._previous_txn_total_time = 0
-        self._current_txn_total_time = 0
-        self._previous_loop_ts = 0
+        self._previous_txn_total_time = 0.0
+        self._current_txn_total_time = 0.0
+        self._previous_loop_ts = 0.0
 
         # TODO(paul): These can eventually be removed once the metrics code
         #   is running in mainline, and we have some nice monitoring frontends
@@ -463,23 +488,23 @@ class Database(object):
             sql_txn_timer.labels(desc).observe(duration)
 
     @defer.inlineCallbacks
-    def runInteraction(self, desc, func, *args, **kwargs):
+    def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
         """Starts a transaction on the database and runs a given function
 
         Arguments:
-            desc (str): description of the transaction, for logging and metrics
-            func (func): callback function, which will be called with a
+            desc: description of the transaction, for logging and metrics
+            func: callback function, which will be called with a
                 database transaction (twisted.enterprise.adbapi.Transaction) as
                 its first argument, followed by `args` and `kwargs`.
 
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
+            args: positional args to pass to `func`
+            kwargs: named args to pass to `func`
 
         Returns:
             Deferred: The result of func
         """
-        after_callbacks = []
-        exception_callbacks = []
+        after_callbacks = []  # type: List[_CallbackListEntry]
+        exception_callbacks = []  # type: List[_CallbackListEntry]
 
         if LoggingContext.current_context() == LoggingContext.sentinel:
             logger.warning("Starting db txn '%s' from sentinel context", desc)
@@ -505,15 +530,15 @@ class Database(object):
         return result
 
     @defer.inlineCallbacks
-    def runWithConnection(self, func, *args, **kwargs):
+    def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
         Arguments:
-            func (func): callback function, which will be called with a
+            func: callback function, which will be called with a
                 database connection (twisted.enterprise.adbapi.Connection) as
                 its first argument, followed by `args` and `kwargs`.
-            args (list): positional args to pass to `func`
-            kwargs (dict): named args to pass to `func`
+            args: positional args to pass to `func`
+            kwargs: named args to pass to `func`
 
         Returns:
             Deferred: The result of func
@@ -800,7 +825,7 @@ class Database(object):
                 return False
 
         # We didn't find any existing rows, so insert a new one
-        allvalues = {}
+        allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
         allvalues.update(values)
         allvalues.update(insertion_values)
@@ -829,7 +854,7 @@ class Database(object):
         Returns:
             None
         """
-        allvalues = {}
+        allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
         allvalues.update(insertion_values)
 
@@ -916,7 +941,7 @@ class Database(object):
         Returns:
             None
         """
-        allnames = []
+        allnames = []  # type: List[str]
         allnames.extend(key_names)
         allnames.extend(value_names)
 
@@ -1100,7 +1125,7 @@ class Database(object):
             keyvalues : dict of column names and values to select the rows with
             retcols : list of strings giving the names of the columns to return
         """
-        results = []
+        results = []  # type: List[Dict[str, Any]]
 
         if not iterable:
             return results
@@ -1439,7 +1464,7 @@ class Database(object):
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
 
         where_clause = "WHERE " if filters or keyvalues else ""
-        arg_list = []
+        arg_list = []  # type: List[Any]
         if filters:
             where_clause += " AND ".join("%s LIKE ?" % (k,) for k in filters)
             arg_list += list(filters.values())
diff --git a/synapse/storage/engines/__init__.py b/synapse/storage/engines/__init__.py
index 9d2d519922..035f9ea6e9 100644
--- a/synapse/storage/engines/__init__.py
+++ b/synapse/storage/engines/__init__.py
@@ -12,29 +12,31 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
-import importlib
 import platform
 
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 from .postgres import PostgresEngine
 from .sqlite import Sqlite3Engine
 
-SUPPORTED_MODULE = {"sqlite3": Sqlite3Engine, "psycopg2": PostgresEngine}
-
 
-def create_engine(database_config):
+def create_engine(database_config) -> BaseDatabaseEngine:
     name = database_config["name"]
-    engine_class = SUPPORTED_MODULE.get(name, None)
 
-    if engine_class:
+    if name == "sqlite3":
+        import sqlite3
+
+        return Sqlite3Engine(sqlite3, database_config)
+
+    if name == "psycopg2":
         # pypy requires psycopg2cffi rather than psycopg2
-        if name == "psycopg2" and platform.python_implementation() == "PyPy":
-            name = "psycopg2cffi"
-        module = importlib.import_module(name)
-        return engine_class(module, database_config)
+        if platform.python_implementation() == "PyPy":
+            import psycopg2cffi as psycopg2  # type: ignore
+        else:
+            import psycopg2  # type: ignore
+
+        return PostgresEngine(psycopg2, database_config)
 
     raise RuntimeError("Unsupported database engine '%s'" % (name,))
 
 
-__all__ = ["create_engine", "IncorrectDatabaseSetup"]
+__all__ = ["create_engine", "BaseDatabaseEngine", "IncorrectDatabaseSetup"]
diff --git a/synapse/storage/engines/_base.py b/synapse/storage/engines/_base.py
index ec5a4d198b..ab0bbe4bd3 100644
--- a/synapse/storage/engines/_base.py
+++ b/synapse/storage/engines/_base.py
@@ -12,7 +12,94 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import abc
+from typing import Generic, TypeVar
+
+from synapse.storage.types import Connection
 
 
 class IncorrectDatabaseSetup(RuntimeError):
     pass
+
+
+ConnectionType = TypeVar("ConnectionType", bound=Connection)
+
+
+class BaseDatabaseEngine(Generic[ConnectionType], metaclass=abc.ABCMeta):
+    def __init__(self, module, database_config: dict):
+        self.module = module
+
+    @property
+    @abc.abstractmethod
+    def single_threaded(self) -> bool:
+        ...
+
+    @property
+    @abc.abstractmethod
+    def can_native_upsert(self) -> bool:
+        """
+        Do we support native UPSERTs?
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def supports_tuple_comparison(self) -> bool:
+        """
+        Do we support comparing tuples, i.e. `(a, b) > (c, d)`?
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def supports_using_any_list(self) -> bool:
+        """
+        Do we support using `a = ANY(?)` and passing a list
+        """
+        ...
+
+    @abc.abstractmethod
+    def check_database(
+        self, db_conn: ConnectionType, allow_outdated_version: bool = False
+    ) -> None:
+        ...
+
+    @abc.abstractmethod
+    def check_new_database(self, txn) -> None:
+        """Gets called when setting up a brand new database. This allows us to
+        apply stricter checks on new databases versus existing database.
+        """
+        ...
+
+    @abc.abstractmethod
+    def convert_param_style(self, sql: str) -> str:
+        ...
+
+    @abc.abstractmethod
+    def on_new_connection(self, db_conn: ConnectionType) -> None:
+        ...
+
+    @abc.abstractmethod
+    def is_deadlock(self, error: Exception) -> bool:
+        ...
+
+    @abc.abstractmethod
+    def is_connection_closed(self, conn: ConnectionType) -> bool:
+        ...
+
+    @abc.abstractmethod
+    def lock_table(self, txn, table: str) -> None:
+        ...
+
+    @abc.abstractmethod
+    def get_next_state_group_id(self, txn) -> int:
+        """Returns an int that can be used as a new state_group ID
+        """
+        ...
+
+    @property
+    @abc.abstractmethod
+    def server_version(self) -> str:
+        """Gets a string giving the server version. For example: '3.22.0'
+        """
+        ...
diff --git a/synapse/storage/engines/postgres.py b/synapse/storage/engines/postgres.py
index 53b3f372b0..6c7d08a6f2 100644
--- a/synapse/storage/engines/postgres.py
+++ b/synapse/storage/engines/postgres.py
@@ -15,16 +15,14 @@
 
 import logging
 
-from ._base import IncorrectDatabaseSetup
+from ._base import BaseDatabaseEngine, IncorrectDatabaseSetup
 
 logger = logging.getLogger(__name__)
 
 
-class PostgresEngine(object):
-    single_threaded = False
-
+class PostgresEngine(BaseDatabaseEngine):
     def __init__(self, database_module, database_config):
-        self.module = database_module
+        super().__init__(database_module, database_config)
         self.module.extensions.register_type(self.module.extensions.UNICODE)
 
         # Disables passing `bytes` to txn.execute, c.f. #6186. If you do
@@ -36,6 +34,10 @@ class PostgresEngine(object):
         self.synchronous_commit = database_config.get("synchronous_commit", True)
         self._version = None  # unknown as yet
 
+    @property
+    def single_threaded(self) -> bool:
+        return False
+
     def check_database(self, db_conn, allow_outdated_version: bool = False):
         # Get the version of PostgreSQL that we're using. As per the psycopg2
         # docs: The number is formed by converting the major, minor, and
diff --git a/synapse/storage/engines/sqlite.py b/synapse/storage/engines/sqlite.py
index 641e490697..2bfeefd54e 100644
--- a/synapse/storage/engines/sqlite.py
+++ b/synapse/storage/engines/sqlite.py
@@ -12,16 +12,16 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import sqlite3
 import struct
 import threading
 
+from synapse.storage.engines import BaseDatabaseEngine
 
-class Sqlite3Engine(object):
-    single_threaded = True
 
+class Sqlite3Engine(BaseDatabaseEngine[sqlite3.Connection]):
     def __init__(self, database_module, database_config):
-        self.module = database_module
+        super().__init__(database_module, database_config)
 
         database = database_config.get("args", {}).get("database")
         self._is_in_memory = database in (None, ":memory:",)
@@ -32,6 +32,10 @@ class Sqlite3Engine(object):
         self._current_state_group_id_lock = threading.Lock()
 
     @property
+    def single_threaded(self) -> bool:
+        return True
+
+    @property
     def can_native_upsert(self):
         """
         Do we support native UPSERTs? This requires SQLite3 3.24+, plus some
@@ -68,7 +72,6 @@ class Sqlite3Engine(object):
         return sql
 
     def on_new_connection(self, db_conn):
-
         # We need to import here to avoid an import loop.
         from synapse.storage.prepare_database import prepare_database
 
diff --git a/synapse/storage/types.py b/synapse/storage/types.py
new file mode 100644
index 0000000000..daff81c5ee
--- /dev/null
+++ b/synapse/storage/types.py
@@ -0,0 +1,65 @@
+# -*- coding: utf-8 -*-
+# Copyright 2020 The Matrix.org Foundation C.I.C.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Iterable, Iterator, List, Tuple
+
+from typing_extensions import Protocol
+
+
+"""
+Some very basic protocol definitions for the DB-API2 classes specified in PEP-249
+"""
+
+
+class Cursor(Protocol):
+    def execute(self, sql: str, parameters: Iterable[Any] = ...) -> Any:
+        ...
+
+    def executemany(self, sql: str, parameters: Iterable[Iterable[Any]]) -> Any:
+        ...
+
+    def fetchall(self) -> List[Tuple]:
+        ...
+
+    def fetchone(self) -> Tuple:
+        ...
+
+    @property
+    def description(self) -> Any:
+        return None
+
+    @property
+    def rowcount(self) -> int:
+        return 0
+
+    def __iter__(self) -> Iterator[Tuple]:
+        ...
+
+    def close(self) -> None:
+        ...
+
+
+class Connection(Protocol):
+    def cursor(self) -> Cursor:
+        ...
+
+    def close(self) -> None:
+        ...
+
+    def commit(self) -> None:
+        ...
+
+    def rollback(self, *args, **kwargs) -> None:
+        ...
diff --git a/tox.ini b/tox.ini
index 4ccfde01b5..6521535137 100644
--- a/tox.ini
+++ b/tox.ini
@@ -168,7 +168,6 @@ commands=
     coverage html
 
 [testenv:mypy]
-basepython = python3.7
 skip_install = True
 deps =
     {[base]deps}
@@ -179,7 +178,8 @@ env =
 extras = all
 commands = mypy \
             synapse/api \
-            synapse/config/ \
+            synapse/appservice \
+            synapse/config \
             synapse/events/spamcheck.py \
             synapse/federation/sender \
             synapse/federation/transport \
@@ -192,6 +192,7 @@ commands = mypy \
             synapse/rest \
             synapse/spam_checker_api \
             synapse/storage/engines \
+            synapse/storage/database.py \
             synapse/streams
 
 # To find all folders that pass mypy you run: