summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py851
1 files changed, 542 insertions, 309 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b112ff3df2..ed8a9bffb1 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -16,6 +16,7 @@
 # limitations under the License.
 import logging
 import time
+from sys import intern
 from time import monotonic as monotonic_time
 from typing import (
     Any,
@@ -27,15 +28,14 @@ from typing import (
     Optional,
     Tuple,
     TypeVar,
+    cast,
+    overload,
 )
 
-from six import iteritems, iterkeys, itervalues
-from six.moves import intern, range
-
 from prometheus_client import Histogram
+from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
-from twisted.internet import defer
 
 from synapse.api.errors import StoreError
 from synapse.config.database import DatabaseConnectionConfig
@@ -51,11 +51,11 @@ from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine, Sqlite3E
 from synapse.storage.types import Connection, Cursor
 from synapse.types import Collection
 
-logger = logging.getLogger(__name__)
-
 # python 3 does not have a maximum int value
 MAX_TXN_ID = 2 ** 63 - 1
 
+logger = logging.getLogger(__name__)
+
 sql_logger = logging.getLogger("synapse.storage.SQL")
 transaction_logger = logging.getLogger("synapse.storage.txn")
 perf_logger = logging.getLogger("synapse.storage.TIME")
@@ -127,7 +127,7 @@ class LoggingTransaction:
     method.
 
     Args:
-        txn: The database transcation object to wrap.
+        txn: The database transaction object to wrap.
         name: The name of this transactions for logging.
         database_engine
         after_callbacks: A list that callbacks will be appended to
@@ -162,7 +162,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
@@ -173,7 +173,9 @@ class LoggingTransaction:
         assert self.after_callbacks is not None
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_on_exception(
+        self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+    ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
         # is not the case.
@@ -197,7 +199,7 @@ class LoggingTransaction:
     def description(self) -> Any:
         return self.txn.description
 
-    def execute_batch(self, sql, args):
+    def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
@@ -206,17 +208,17 @@ class LoggingTransaction:
             for val in args:
                 self.execute(sql, val)
 
-    def execute(self, sql: str, *args: Any):
+    def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
 
-    def executemany(self, sql: str, *args: Any):
+    def executemany(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.executemany, sql, *args)
 
     def _make_sql_one_line(self, sql: str) -> str:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql, *args):
+    def _do_execute(self, func, sql: str, *args: Any) -> None:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -235,31 +237,31 @@ class LoggingTransaction:
         try:
             return func(sql, *args)
         except Exception as e:
-            logger.debug("[SQL FAIL] {%s} %s", self.name, e)
+            sql_logger.debug("[SQL FAIL] {%s} %s", self.name, e)
             raise
         finally:
             secs = time.time() - start
             sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
             sql_query_timer.labels(sql.split()[0]).observe(secs)
 
-    def close(self):
+    def close(self) -> None:
         self.txn.close()
 
 
-class PerformanceCounters(object):
+class PerformanceCounters:
     def __init__(self):
         self.current_counters = {}
         self.previous_counters = {}
 
-    def update(self, key, duration_secs):
+    def update(self, key: str, duration_secs: float) -> None:
         count, cum_time = self.current_counters.get(key, (0, 0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)
 
-    def interval(self, interval_duration_secs, limit=3):
+    def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
         counters = []
-        for name, (count, cum_time) in iteritems(self.current_counters):
+        for name, (count, cum_time) in self.current_counters.items():
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
             counters.append(
                 (
@@ -281,7 +283,10 @@ class PerformanceCounters(object):
         return top_n_counters
 
 
-class Database(object):
+R = TypeVar("R")
+
+
+class DatabasePool:
     """Wraps a single physical database and connection pool.
 
     A single database may be used by multiple data stores.
@@ -329,13 +334,12 @@ class Database(object):
                 self._check_safe_to_upsert,
             )
 
-    def is_running(self):
+    def is_running(self) -> bool:
         """Is the database pool currently running
         """
         return self._db_pool.running
 
-    @defer.inlineCallbacks
-    def _check_safe_to_upsert(self):
+    async def _check_safe_to_upsert(self) -> None:
         """
         Is it safe to use native UPSERT?
 
@@ -344,7 +348,7 @@ class Database(object):
 
         If the background updates have not completed, wait 15 sec and check again.
         """
-        updates = yield self.simple_select_list(
+        updates = await self.simple_select_list(
             "background_updates",
             keyvalues=None,
             retcols=["update_name"],
@@ -366,7 +370,7 @@ class Database(object):
                 self._check_safe_to_upsert,
             )
 
-    def start_profiling(self):
+    def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
         def loop():
@@ -390,8 +394,15 @@ class Database(object):
         self._clock.looping_call(loop, 10000)
 
     def new_transaction(
-        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
-    ):
+        self,
+        conn: Connection,
+        desc: str,
+        after_callbacks: List[_CallbackListEntry],
+        exception_callbacks: List[_CallbackListEntry],
+        func: "Callable[..., R]",
+        *args: Any,
+        **kwargs: Any
+    ) -> R:
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -421,7 +432,7 @@ class Database(object):
                 except self.engine.module.OperationalError as e:
                     # This can happen if the database disappears mid
                     # transaction.
-                    logger.warning(
+                    transaction_logger.warning(
                         "[TXN OPERROR] {%s} %s %d/%d", name, e, i, N,
                     )
                     if i < N:
@@ -429,18 +440,20 @@ class Database(object):
                         try:
                             conn.rollback()
                         except self.engine.module.Error as e1:
-                            logger.warning("[TXN EROLL] {%s} %s", name, e1)
+                            transaction_logger.warning("[TXN EROLL] {%s} %s", name, e1)
                         continue
                     raise
                 except self.engine.module.DatabaseError as e:
                     if self.engine.is_deadlock(e):
-                        logger.warning("[TXN DEADLOCK] {%s} %d/%d", name, i, N)
+                        transaction_logger.warning(
+                            "[TXN DEADLOCK] {%s} %d/%d", name, i, N
+                        )
                         if i < N:
                             i += 1
                             try:
                                 conn.rollback()
                             except self.engine.module.Error as e1:
-                                logger.warning(
+                                transaction_logger.warning(
                                     "[TXN EROLL] {%s} %s", name, e1,
                                 )
                             continue
@@ -480,7 +493,7 @@ class Database(object):
                     # [2]: https://github.com/python/cpython/blob/v3.8.0/Modules/_sqlite/cursor.c#L236
                     cursor.close()
         except Exception as e:
-            logger.debug("[TXN FAIL] {%s} %s", name, e)
+            transaction_logger.debug("[TXN FAIL] {%s} %s", name, e)
             raise
         finally:
             end = monotonic_time()
@@ -494,8 +507,9 @@ class Database(object):
             self._txn_perf_counters.update(desc, duration)
             sql_txn_timer.labels(desc).observe(duration)
 
-    @defer.inlineCallbacks
-    def runInteraction(self, desc: str, func: Callable, *args: Any, **kwargs: Any):
+    async def runInteraction(
+        self, desc: str, func: "Callable[..., R]", *args: Any, **kwargs: Any
+    ) -> R:
         """Starts a transaction on the database and runs a given function
 
         Arguments:
@@ -508,7 +522,7 @@ class Database(object):
             kwargs: named args to pass to `func`
 
         Returns:
-            Deferred: The result of func
+            The result of func
         """
         after_callbacks = []  # type: List[_CallbackListEntry]
         exception_callbacks = []  # type: List[_CallbackListEntry]
@@ -517,7 +531,7 @@ class Database(object):
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
-            result = yield self.runWithConnection(
+            result = await self.runWithConnection(
                 self.new_transaction,
                 desc,
                 after_callbacks,
@@ -534,10 +548,11 @@ class Database(object):
                 after_callback(*after_args, **after_kwargs)
             raise
 
-        return result
+        return cast(R, result)
 
-    @defer.inlineCallbacks
-    def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+    async def runWithConnection(
+        self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+    ) -> R:
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
         Arguments:
@@ -548,7 +563,7 @@ class Database(object):
             kwargs: named args to pass to `func`
 
         Returns:
-            Deferred: The result of func
+            The result of func
         """
         parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
         if not parent_context:
@@ -571,18 +586,16 @@ class Database(object):
 
                 return func(conn, *args, **kwargs)
 
-        result = yield make_deferred_yieldable(
+        return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
         )
 
-        return result
-
     @staticmethod
-    def cursor_to_dict(cursor):
+    def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
         """Converts a SQL cursor into an list of dicts.
 
         Args:
-            cursor : The DBAPI cursor which has executed a query.
+            cursor: The DBAPI cursor which has executed a query.
         Returns:
             A list of dicts where the key is the column header.
         """
@@ -590,10 +603,29 @@ class Database(object):
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
-    def execute(self, desc, decoder, query, *args):
+    @overload
+    async def execute(
+        self, desc: str, decoder: Literal[None], query: str, *args: Any
+    ) -> List[Tuple[Any, ...]]:
+        ...
+
+    @overload
+    async def execute(
+        self, desc: str, decoder: Callable[[Cursor], R], query: str, *args: Any
+    ) -> R:
+        ...
+
+    async def execute(
+        self,
+        desc: str,
+        decoder: Optional[Callable[[Cursor], R]],
+        query: str,
+        *args: Any
+    ) -> R:
         """Runs a single query for a result set.
 
         Args:
+            desc: description of the transaction, for logging and metrics
             decoder - The function which can resolve the cursor results to
                 something meaningful.
             query - The query string to execute
@@ -609,29 +641,33 @@ class Database(object):
             else:
                 return txn.fetchall()
 
-        return self.runInteraction(desc, interaction)
+        return await self.runInteraction(desc, interaction)
 
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.
 
-    @defer.inlineCallbacks
-    def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+    async def simple_insert(
+        self,
+        table: str,
+        values: Dict[str, Any],
+        or_ignore: bool = False,
+        desc: str = "simple_insert",
+    ) -> bool:
         """Executes an INSERT query on the named table.
 
         Args:
-            table : string giving the table name
-            values : dict of new column names and values for them
-            or_ignore : bool stating whether an exception should be raised
+            table: string giving the table name
+            values: dict of new column names and values for them
+            or_ignore: bool stating whether an exception should be raised
                 when a conflicting row already exists. If True, False will be
                 returned by the function instead
-            desc : string giving a description of the transaction
+            desc: description of the transaction, for logging and metrics
 
         Returns:
-            bool: Whether the row was inserted or not. Only useful when
-            `or_ignore` is True
+             Whether the row was inserted or not. Only useful when `or_ignore` is True
         """
         try:
-            yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+            await self.runInteraction(desc, self.simple_insert_txn, table, values)
         except self.engine.module.IntegrityError:
             # We have to do or_ignore flag at this layer, since we can't reuse
             # a cursor after we receive an error from the db.
@@ -641,7 +677,9 @@ class Database(object):
         return True
 
     @staticmethod
-    def simple_insert_txn(txn, table, values):
+    def simple_insert_txn(
+        txn: LoggingTransaction, table: str, values: Dict[str, Any]
+    ) -> None:
         keys, vals = zip(*values.items())
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -652,11 +690,29 @@ class Database(object):
 
         txn.execute(sql, vals)
 
-    def simple_insert_many(self, table, values, desc):
-        return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
+    async def simple_insert_many(
+        self, table: str, values: List[Dict[str, Any]], desc: str
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        Args:
+            table: string giving the table name
+            values: dict of new column names and values for them
+            desc: description of the transaction, for logging and metrics
+        """
+        await self.runInteraction(desc, self.simple_insert_many_txn, table, values)
 
     @staticmethod
-    def simple_insert_many_txn(txn, table, values):
+    def simple_insert_many_txn(
+        txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+    ) -> None:
+        """Executes an INSERT query on the named table.
+
+        Args:
+            txn: The transaction to use.
+            table: string giving the table name
+            values: dict of new column names and values for them
+        """
         if not values:
             return
 
@@ -684,16 +740,15 @@ class Database(object):
 
         txn.executemany(sql, vals)
 
-    @defer.inlineCallbacks
-    def simple_upsert(
+    async def simple_upsert(
         self,
-        table,
-        keyvalues,
-        values,
-        insertion_values={},
-        desc="simple_upsert",
-        lock=True,
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        desc: str = "simple_upsert",
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
 
         `lock` should generally be set to True (the default), but can be set
@@ -707,21 +762,20 @@ class Database(object):
           this table.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key columns and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            desc: description of the transaction, for logging and metrics
+            lock: True to lock the table when doing the upsert.
         Returns:
-            Deferred(None or bool): Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         attempts = 0
         while True:
             try:
-                result = yield self.runInteraction(
+                return await self.runInteraction(
                     desc,
                     self.simple_upsert_txn,
                     table,
@@ -730,7 +784,6 @@ class Database(object):
                     insertion_values,
                     lock=lock,
                 )
-                return result
             except self.engine.module.IntegrityError as e:
                 attempts += 1
                 if attempts >= 5:
@@ -744,29 +797,34 @@ class Database(object):
                 )
 
     def simple_upsert_txn(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
         Pick the UPSERT method which works best on the platform. Either the
         native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
 
         Args:
             txn: The transaction to use.
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            None or bool: Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
-            return self.simple_upsert_txn_native_upsert(
+            self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
+            return None
         else:
             return self.simple_upsert_txn_emulated(
                 txn,
@@ -778,18 +836,23 @@ class Database(object):
             )
 
     def simple_upsert_txn_emulated(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> bool:
         """
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            bool: Return True if a new entry was created, False if an existing
+            Returns True if a new entry was created, False if an existing
             one was updated.
         """
         # We need to lock the table :(, unless we're *really* careful
@@ -847,19 +910,21 @@ class Database(object):
         return True
 
     def simple_upsert_txn_native_upsert(
-        self, txn, table, keyvalues, values, insertion_values={}
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+    ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-        Returns:
-            None
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
@@ -989,41 +1054,93 @@ class Database(object):
 
         return txn.execute_batch(sql, args)
 
-    def simple_select_one(
-        self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
-    ):
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one",
+    ) -> Dict[str, Any]:
+        ...
+
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
+        ...
+
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcols : list of strings giving the names of the columns to return
-
-            allow_none : If true, return None instead of failing if the SELECT
-              statement returns no rows
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcols: list of strings giving the names of the columns to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
-    def simple_select_one_onecol(
+    @overload
+    async def simple_select_one_onecol(
         self,
-        table,
-        keyvalues,
-        retcol,
-        allow_none=False,
-        desc="simple_select_one_onecol",
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Any:
+        ...
+
+    @overload
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
+        ...
+
+    async def simple_select_one_onecol(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: bool = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcol : string giving the name of the column to return
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcol: string giving the name of the column to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc,
             self.simple_select_one_onecol_txn,
             table,
@@ -1032,10 +1149,39 @@ class Database(object):
             allow_none=allow_none,
         )
 
+    @overload
     @classmethod
     def simple_select_one_onecol_txn(
-        cls, txn, table, keyvalues, retcol, allow_none=False
-    ):
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: Literal[False] = False,
+    ) -> Any:
+        ...
+
+    @overload
+    @classmethod
+    def simple_select_one_onecol_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: Literal[True] = True,
+    ) -> Optional[Any]:
+        ...
+
+    @classmethod
+    def simple_select_one_onecol_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: str,
+        allow_none: bool = False,
+    ) -> Optional[Any]:
         ret = cls.simple_select_onecol_txn(
             txn, table=table, keyvalues=keyvalues, retcol=retcol
         )
@@ -1049,64 +1195,85 @@ class Database(object):
                 raise StoreError(404, "No row found")
 
     @staticmethod
-    def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+    def simple_select_onecol_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any], retcol: str,
+    ) -> List[Any]:
         sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
 
         if keyvalues:
-            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+            sql += " WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
             txn.execute(sql, list(keyvalues.values()))
         else:
             txn.execute(sql)
 
         return [r[0] for r in txn]
 
-    def simple_select_onecol(
-        self, table, keyvalues, retcol, desc="simple_select_onecol"
-    ):
+    async def simple_select_onecol(
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcol: str,
+        desc: str = "simple_select_onecol",
+    ) -> List[Any]:
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
 
         Args:
-            table (str): table name
-            keyvalues (dict|None): column names and values to select the rows with
-            retcol (str): column whos value we wish to retrieve.
+            table: table name
+            keyvalues: column names and values to select the rows with
+            retcol: column whos value we wish to retrieve.
+            desc: description of the transaction, for logging and metrics
 
         Returns:
-            Deferred: Results in a list
+            Results in a list
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_onecol_txn, table, keyvalues, retcol
         )
 
-    def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+    async def simple_select_list(
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+        desc: str = "simple_select_list",
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            keyvalues (dict[str, Any] | None):
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
+            desc: description of the transaction, for logging and metrics
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            A list of dictionaries.
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_list_txn, table, keyvalues, retcols
         )
 
     @classmethod
-    def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+    def simple_select_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
         """
         if keyvalues:
             sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1121,28 +1288,29 @@ class Database(object):
 
         return cls.cursor_to_dict(txn)
 
-    @defer.inlineCallbacks
-    def simple_select_many_batch(
+    async def simple_select_many_batch(
         self,
-        table,
-        column,
-        iterable,
-        retcols,
-        keyvalues={},
-        desc="simple_select_many_batch",
-        batch_size=100,
-    ):
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        retcols: Iterable[str],
+        keyvalues: Dict[str, Any] = {},
+        desc: str = "simple_select_many_batch",
+        batch_size: int = 100,
+    ) -> List[Any]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
-        Filters rows by if value of `column` is in `iterable`.
+        Filters rows by whether the value of `column` is in `iterable`.
 
         Args:
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            retcols: list of strings giving the names of the columns to return
+            keyvalues: dict of column names and values to select the rows with
+            desc: description of the transaction, for logging and metrics
+            batch_size: the number of rows for each select query
         """
         results = []  # type: List[Dict[str, Any]]
 
@@ -1156,7 +1324,7 @@ class Database(object):
             it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
         ]
         for chunk in chunks:
-            rows = yield self.runInteraction(
+            rows = await self.runInteraction(
                 desc,
                 self.simple_select_many_txn,
                 table,
@@ -1171,19 +1339,27 @@ class Database(object):
         return results
 
     @classmethod
-    def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+    def simple_select_many_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
-        Filters rows by if value of `column` is in `iterable`.
+        Filters rows by whether the value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            retcols: list of strings giving the names of the columns to return
         """
         if not iterable:
             return []
@@ -1191,7 +1367,7 @@ class Database(object):
         clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
         clauses = [clause]
 
-        for key, value in iteritems(keyvalues):
+        for key, value in keyvalues.items():
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -1204,15 +1380,26 @@ class Database(object):
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
 
-    def simple_update(self, table, keyvalues, updatevalues, desc):
-        return self.runInteraction(
+    async def simple_update(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str,
+    ) -> int:
+        return await self.runInteraction(
             desc, self.simple_update_txn, table, keyvalues, updatevalues
         )
 
     @staticmethod
-    def simple_update_txn(txn, table, keyvalues, updatevalues):
+    def simple_update_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> int:
         if keyvalues:
-            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in iterkeys(keyvalues))
+            where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
         else:
             where = ""
 
@@ -1226,32 +1413,34 @@ class Database(object):
 
         return txn.rowcount
 
-    def simple_update_one(
-        self, table, keyvalues, updatevalues, desc="simple_update_one"
-    ):
+    async def simple_update_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str = "simple_update_one",
+    ) -> None:
         """Executes an UPDATE query on the named table, setting new values for
         columns in a row matching the key values.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            updatevalues : dict giving column names and values to update
-            retcols : optional list of column names to return
-
-        If present, retcols gives a list of column names on which to perform
-        a SELECT statement *before* performing the UPDATE statement. The values
-        of these will be returned in a dict.
-
-        These are performed within the same transaction, allowing an atomic
-        get-and-set.  This can be used to implement compare-and-set by putting
-        the update column in the 'keyvalues' dict as well.
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            updatevalues: dict giving column names and values to update
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        await self.runInteraction(
             desc, self.simple_update_one_txn, table, keyvalues, updatevalues
         )
 
     @classmethod
-    def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+    def simple_update_one_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> None:
         rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
 
         if rowcount == 0:
@@ -1259,8 +1448,18 @@ class Database(object):
         if rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
+    # Ideally we could use the overload decorator here to specify that the
+    # return type is only optional if allow_none is True, but this does not work
+    # when you call a static method from an instance.
+    # See https://github.com/python/mypy/issues/7781
     @staticmethod
-    def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+    def simple_select_one_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+    ) -> Optional[Dict[str, Any]]:
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
@@ -1279,24 +1478,29 @@ class Database(object):
 
         return dict(zip(retcols, row))
 
-    def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+    async def simple_delete_one(
+        self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+    ) -> None:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
+        await self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_one_txn(txn, table, keyvalues):
+    def simple_delete_one_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> None:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
         """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
@@ -1309,11 +1513,38 @@ class Database(object):
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
-    def simple_delete(self, table, keyvalues, desc):
-        return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
+    async def simple_delete(
+        self, table: str, keyvalues: Dict[str, Any], desc: str
+    ) -> int:
+        """Executes a DELETE query on the named table.
+
+        Filters rows by the key-value pairs.
+
+        Args:
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            desc: description of the transaction, for logging and metrics
+
+        Returns:
+            The number of deleted rows.
+        """
+        return await self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_txn(txn, table, keyvalues):
+    def simple_delete_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> int:
+        """Executes a DELETE query on the named table.
+
+        Filters rows by the key-value pairs.
+
+        Args:
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+
+        Returns:
+            The number of deleted rows.
+        """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1322,26 +1553,53 @@ class Database(object):
         txn.execute(sql, list(keyvalues.values()))
         return txn.rowcount
 
-    def simple_delete_many(self, table, column, iterable, keyvalues, desc):
-        return self.runInteraction(
+    async def simple_delete_many(
+        self,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        desc: str,
+    ) -> int:
+        """Executes a DELETE query on the named table.
+
+        Filters rows by if value of `column` is in `iterable`.
+
+        Args:
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            desc: description of the transaction, for logging and metrics
+
+        Returns:
+            Number rows deleted
+        """
+        return await self.runInteraction(
             desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
         )
 
     @staticmethod
-    def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+    def simple_delete_many_txn(
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+    ) -> int:
         """Executes a DELETE query on the named table.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
 
         Returns:
-            int: Number rows deleted
+            Number rows deleted
         """
         if not iterable:
             return 0
@@ -1351,7 +1609,7 @@ class Database(object):
         clause, values = make_in_list_sql_clause(txn.database_engine, column, iterable)
         clauses = [clause]
 
-        for key, value in iteritems(keyvalues):
+        for key, value in keyvalues.items():
             clauses.append("%s = ?" % (key,))
             values.append(value)
 
@@ -1362,8 +1620,14 @@ class Database(object):
         return txn.rowcount
 
     def get_cache_dict(
-        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
-    ):
+        self,
+        db_conn: Connection,
+        table: str,
+        entity_column: str,
+        stream_column: str,
+        max_value: int,
+        limit: int = 100000,
+    ) -> Tuple[Dict[Any, int], int]:
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
         # It doesn't really matter how many we get, the StreamChangeCache will
         # do the right thing to ensure it respects the max size of cache.
@@ -1388,71 +1652,25 @@ class Database(object):
         txn.close()
 
         if cache:
-            min_val = min(itervalues(cache))
+            min_val = min(cache.values())
         else:
             min_val = max_value
 
         return cache, min_val
 
-    def simple_select_list_paginate(
-        self,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-        desc="simple_select_list_paginate",
-    ):
-        """
-        Executes a SELECT query on the named table with start and limit,
-        of row numbers, which may return zero or number of rows from start to limit,
-        returning the result as a list of dicts.
-
-        Args:
-            table (str): the table name
-            filters (dict[str, T] | None):
-                column names and values to filter the rows with, or None to not
-                apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
-                column names and values to select the rows with, or None to not
-                apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
-        Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
-        """
-        return self.runInteraction(
-            desc,
-            self.simple_select_list_paginate_txn,
-            table,
-            orderby,
-            start,
-            limit,
-            retcols,
-            filters=filters,
-            keyvalues=keyvalues,
-            order_direction=order_direction,
-        )
-
     @classmethod
     def simple_select_list_paginate_txn(
         cls,
-        txn,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-    ):
+        txn: LoggingTransaction,
+        table: str,
+        orderby: str,
+        start: int,
+        limit: int,
+        retcols: Iterable[str],
+        filters: Optional[Dict[str, Any]] = None,
+        keyvalues: Optional[Dict[str, Any]] = None,
+        order_direction: str = "ASC",
+    ) -> List[Dict[str, Any]]:
         """
         Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
@@ -1463,21 +1681,22 @@ class Database(object):
         using 'AND'.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            filters (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            orderby: Column to order the results by.
+            start: Index to begin the query at.
+            limit: Number of results to return.
+            retcols: the names of the columns to return
+            filters:
                 column names and values to filter the rows with, or None to not
                 apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+            order_direction: Whether the results should be ordered "ASC" or "DESC".
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            The result as a list of dictionaries.
         """
         if order_direction not in ["ASC", "DESC"]:
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1503,51 +1722,65 @@ class Database(object):
 
         return cls.cursor_to_dict(txn)
 
-    def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+    async def simple_search_list(
+        self,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+        desc="simple_search_list",
+    ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
+            A list of dictionaries or None.
         """
 
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_search_list_txn, table, term, col, retcols
         )
 
     @classmethod
-    def simple_search_list_txn(cls, txn, table, term, col, retcols):
+    def simple_search_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+    ) -> Optional[List[Dict[str, Any]]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            txn: Transaction object
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
+            None if no term is given, otherwise a list of dictionaries.
         """
         if term:
             sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
             termvalues = ["%%" + term + "%%"]
             txn.execute(sql, termvalues)
         else:
-            return 0
+            return None
 
         return cls.cursor_to_dict(txn)
 
 
 def make_in_list_sql_clause(
-    database_engine, column: str, iterable: Iterable
+    database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
 ) -> Tuple[str, list]:
     """Returns an SQL clause that checks the given column is in the iterable.