summary refs log tree commit diff
path: root/synapse/storage/database.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/database.py')
-rw-r--r--synapse/storage/database.py651
1 files changed, 404 insertions, 247 deletions
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index 4ada6f5563..181c3ec249 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,9 +28,12 @@ from typing import (
     Optional,
     Tuple,
     TypeVar,
+    Union,
+    overload,
 )
 
 from prometheus_client import Histogram
+from typing_extensions import Literal
 
 from twisted.enterprise import adbapi
 from twisted.internet import defer
@@ -125,7 +128,7 @@ class LoggingTransaction:
     method.
 
     Args:
-        txn: The database transcation object to wrap.
+        txn: The database transaction object to wrap.
         name: The name of this transactions for logging.
         database_engine
         after_callbacks: A list that callbacks will be appended to
@@ -160,7 +163,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
@@ -171,7 +174,9 @@ class LoggingTransaction:
         assert self.after_callbacks is not None
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_on_exception(
+        self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+    ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
         # is not the case.
@@ -195,7 +200,7 @@ class LoggingTransaction:
     def description(self) -> Any:
         return self.txn.description
 
-    def execute_batch(self, sql, args):
+    def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
@@ -204,17 +209,17 @@ class LoggingTransaction:
             for val in args:
                 self.execute(sql, val)
 
-    def execute(self, sql: str, *args: Any):
+    def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
 
-    def executemany(self, sql: str, *args: Any):
+    def executemany(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.executemany, sql, *args)
 
     def _make_sql_one_line(self, sql: str) -> str:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql, *args):
+    def _do_execute(self, func, sql: str, *args: Any) -> None:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -240,7 +245,7 @@ class LoggingTransaction:
             sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
             sql_query_timer.labels(sql.split()[0]).observe(secs)
 
-    def close(self):
+    def close(self) -> None:
         self.txn.close()
 
 
@@ -249,13 +254,13 @@ class PerformanceCounters(object):
         self.current_counters = {}
         self.previous_counters = {}
 
-    def update(self, key, duration_secs):
+    def update(self, key: str, duration_secs: float) -> None:
         count, cum_time = self.current_counters.get(key, (0, 0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)
 
-    def interval(self, interval_duration_secs, limit=3):
+    def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
         counters = []
         for name, (count, cum_time) in self.current_counters.items():
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
@@ -279,6 +284,9 @@ class PerformanceCounters(object):
         return top_n_counters
 
 
+R = TypeVar("R")
+
+
 class DatabasePool(object):
     """Wraps a single physical database and connection pool.
 
@@ -327,13 +335,12 @@ class DatabasePool(object):
                 self._check_safe_to_upsert,
             )
 
-    def is_running(self):
+    def is_running(self) -> bool:
         """Is the database pool currently running
         """
         return self._db_pool.running
 
-    @defer.inlineCallbacks
-    def _check_safe_to_upsert(self):
+    async def _check_safe_to_upsert(self) -> None:
         """
         Is it safe to use native UPSERT?
 
@@ -342,7 +349,7 @@ class DatabasePool(object):
 
         If the background updates have not completed, wait 15 sec and check again.
         """
-        updates = yield self.simple_select_list(
+        updates = await self.simple_select_list(
             "background_updates",
             keyvalues=None,
             retcols=["update_name"],
@@ -364,7 +371,7 @@ class DatabasePool(object):
                 self._check_safe_to_upsert,
             )
 
-    def start_profiling(self):
+    def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
         def loop():
@@ -388,8 +395,15 @@ class DatabasePool(object):
         self._clock.looping_call(loop, 10000)
 
     def new_transaction(
-        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
-    ):
+        self,
+        conn: Connection,
+        desc: str,
+        after_callbacks: List[_CallbackListEntry],
+        exception_callbacks: List[_CallbackListEntry],
+        func: "Callable[..., R]",
+        *args: Any,
+        **kwargs: Any
+    ) -> R:
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -517,14 +531,16 @@ class DatabasePool(object):
             logger.warning("Starting db txn '%s' from sentinel context", desc)
 
         try:
-            result = yield self.runWithConnection(
-                self.new_transaction,
-                desc,
-                after_callbacks,
-                exception_callbacks,
-                func,
-                *args,
-                **kwargs
+            result = yield defer.ensureDeferred(
+                self.runWithConnection(
+                    self.new_transaction,
+                    desc,
+                    after_callbacks,
+                    exception_callbacks,
+                    func,
+                    *args,
+                    **kwargs
+                )
             )
 
             for after_callback, after_args, after_kwargs in after_callbacks:
@@ -536,8 +552,9 @@ class DatabasePool(object):
 
         return result
 
-    @defer.inlineCallbacks
-    def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any):
+    async def runWithConnection(
+        self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+    ) -> R:
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
         Arguments:
@@ -548,7 +565,7 @@ class DatabasePool(object):
             kwargs: named args to pass to `func`
 
         Returns:
-            Deferred: The result of func
+            The result of func
         """
         parent_context = current_context()  # type: Optional[LoggingContextOrSentinel]
         if not parent_context:
@@ -571,18 +588,16 @@ class DatabasePool(object):
 
                 return func(conn, *args, **kwargs)
 
-        result = yield make_deferred_yieldable(
+        return await make_deferred_yieldable(
             self._db_pool.runWithConnection(inner_func, *args, **kwargs)
         )
 
-        return result
-
     @staticmethod
-    def cursor_to_dict(cursor):
+    def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
         """Converts a SQL cursor into an list of dicts.
 
         Args:
-            cursor : The DBAPI cursor which has executed a query.
+            cursor: The DBAPI cursor which has executed a query.
         Returns:
             A list of dicts where the key is the column header.
         """
@@ -590,7 +605,7 @@ class DatabasePool(object):
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
-    def execute(self, desc, decoder, query, *args):
+    def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
         """Runs a single query for a result set.
 
         Args:
@@ -599,7 +614,7 @@ class DatabasePool(object):
             query - The query string to execute
             *args - Query args.
         Returns:
-            The result of decoder(results)
+            Deferred which results to the result of decoder(results)
         """
 
         def interaction(txn):
@@ -614,24 +629,28 @@ class DatabasePool(object):
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.
 
-    @defer.inlineCallbacks
-    def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+    async def simple_insert(
+        self,
+        table: str,
+        values: Dict[str, Any],
+        or_ignore: bool = False,
+        desc: str = "simple_insert",
+    ) -> bool:
         """Executes an INSERT query on the named table.
 
         Args:
-            table : string giving the table name
-            values : dict of new column names and values for them
-            or_ignore : bool stating whether an exception should be raised
+            table: string giving the table name
+            values: dict of new column names and values for them
+            or_ignore: bool stating whether an exception should be raised
                 when a conflicting row already exists. If True, False will be
                 returned by the function instead
-            desc : string giving a description of the transaction
+            desc: string giving a description of the transaction
 
         Returns:
-            bool: Whether the row was inserted or not. Only useful when
-            `or_ignore` is True
+             Whether the row was inserted or not. Only useful when `or_ignore` is True
         """
         try:
-            yield self.runInteraction(desc, self.simple_insert_txn, table, values)
+            await self.runInteraction(desc, self.simple_insert_txn, table, values)
         except self.engine.module.IntegrityError:
             # We have to do or_ignore flag at this layer, since we can't reuse
             # a cursor after we receive an error from the db.
@@ -641,7 +660,9 @@ class DatabasePool(object):
         return True
 
     @staticmethod
-    def simple_insert_txn(txn, table, values):
+    def simple_insert_txn(
+        txn: LoggingTransaction, table: str, values: Dict[str, Any]
+    ) -> None:
         keys, vals = zip(*values.items())
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -652,11 +673,15 @@ class DatabasePool(object):
 
         txn.execute(sql, vals)
 
-    def simple_insert_many(self, table, values, desc):
+    def simple_insert_many(
+        self, table: str, values: List[Dict[str, Any]], desc: str
+    ) -> defer.Deferred:
         return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
 
     @staticmethod
-    def simple_insert_many_txn(txn, table, values):
+    def simple_insert_many_txn(
+        txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+    ) -> None:
         if not values:
             return
 
@@ -684,16 +709,15 @@ class DatabasePool(object):
 
         txn.executemany(sql, vals)
 
-    @defer.inlineCallbacks
-    def simple_upsert(
+    async def simple_upsert(
         self,
-        table,
-        keyvalues,
-        values,
-        insertion_values={},
-        desc="simple_upsert",
-        lock=True,
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        desc: str = "simple_upsert",
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
 
         `lock` should generally be set to True (the default), but can be set
@@ -707,21 +731,19 @@ class DatabasePool(object):
           this table.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key columns and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            Deferred(None or bool): Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         attempts = 0
         while True:
             try:
-                result = yield self.runInteraction(
+                return await self.runInteraction(
                     desc,
                     self.simple_upsert_txn,
                     table,
@@ -730,7 +752,6 @@ class DatabasePool(object):
                     insertion_values,
                     lock=lock,
                 )
-                return result
             except self.engine.module.IntegrityError as e:
                 attempts += 1
                 if attempts >= 5:
@@ -744,29 +765,34 @@ class DatabasePool(object):
                 )
 
     def simple_upsert_txn(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
         Pick the UPSERT method which works best on the platform. Either the
         native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
 
         Args:
             txn: The transaction to use.
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            None or bool: Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
-            return self.simple_upsert_txn_native_upsert(
+            self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
+            return None
         else:
             return self.simple_upsert_txn_emulated(
                 txn,
@@ -778,18 +804,23 @@ class DatabasePool(object):
             )
 
     def simple_upsert_txn_emulated(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> bool:
         """
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            bool: Return True if a new entry was created, False if an existing
+            Returns True if a new entry was created, False if an existing
             one was updated.
         """
         # We need to lock the table :(, unless we're *really* careful
@@ -847,19 +878,21 @@ class DatabasePool(object):
         return True
 
     def simple_upsert_txn_native_upsert(
-        self, txn, table, keyvalues, values, insertion_values={}
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+    ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-        Returns:
-            None
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
@@ -989,41 +1022,70 @@ class DatabasePool(object):
 
         return txn.execute_batch(sql, args)
 
-    def simple_select_one(
-        self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
-    ):
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[False] = False,
+        desc: str = "simple_select_one",
+    ) -> Dict[str, Any]:
+        ...
+
+    @overload
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: Literal[True] = True,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
+        ...
+
+    async def simple_select_one(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+        desc: str = "simple_select_one",
+    ) -> Optional[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcols : list of strings giving the names of the columns to return
-
-            allow_none : If true, return None instead of failing if the SELECT
-              statement returns no rows
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcols: list of strings giving the names of the columns to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
         )
 
-    def simple_select_one_onecol(
+    async def simple_select_one_onecol(
         self,
-        table,
-        keyvalues,
-        retcol,
-        allow_none=False,
-        desc="simple_select_one_onecol",
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: bool = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> Optional[Any]:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcol : string giving the name of the column to return
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcol: string giving the name of the column to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
+            desc: description of the transaction, for logging and metrics
         """
-        return self.runInteraction(
+        return await self.runInteraction(
             desc,
             self.simple_select_one_onecol_txn,
             table,
@@ -1034,8 +1096,13 @@ class DatabasePool(object):
 
     @classmethod
     def simple_select_one_onecol_txn(
-        cls, txn, table, keyvalues, retcol, allow_none=False
-    ):
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: bool = False,
+    ) -> Optional[Any]:
         ret = cls.simple_select_onecol_txn(
             txn, table=table, keyvalues=keyvalues, retcol=retcol
         )
@@ -1049,7 +1116,12 @@ class DatabasePool(object):
                 raise StoreError(404, "No row found")
 
     @staticmethod
-    def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+    def simple_select_onecol_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+    ) -> List[Any]:
         sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
 
         if keyvalues:
@@ -1061,15 +1133,19 @@ class DatabasePool(object):
         return [r[0] for r in txn]
 
     def simple_select_onecol(
-        self, table, keyvalues, retcol, desc="simple_select_onecol"
-    ):
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcol: str,
+        desc: str = "simple_select_onecol",
+    ) -> defer.Deferred:
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
 
         Args:
-            table (str): table name
-            keyvalues (dict|None): column names and values to select the rows with
-            retcol (str): column whos value we wish to retrieve.
+            table: table name
+            keyvalues: column names and values to select the rows with
+            retcol: column whos value we wish to retrieve.
 
         Returns:
             Deferred: Results in a list
@@ -1078,16 +1154,22 @@ class DatabasePool(object):
             desc, self.simple_select_onecol_txn, table, keyvalues, retcol
         )
 
-    def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+    def simple_select_list(
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+        desc: str = "simple_select_list",
+    ) -> defer.Deferred:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            keyvalues (dict[str, Any] | None):
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
@@ -1096,17 +1178,23 @@ class DatabasePool(object):
         )
 
     @classmethod
-    def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+    def simple_select_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
         """
         if keyvalues:
             sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1121,28 +1209,27 @@ class DatabasePool(object):
 
         return cls.cursor_to_dict(txn)
 
-    @defer.inlineCallbacks
-    def simple_select_many_batch(
+    async def simple_select_many_batch(
         self,
-        table,
-        column,
-        iterable,
-        retcols,
-        keyvalues={},
-        desc="simple_select_many_batch",
-        batch_size=100,
-    ):
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        retcols: Iterable[str],
+        keyvalues: Dict[str, Any] = {},
+        desc: str = "simple_select_many_batch",
+        batch_size: int = 100,
+    ) -> List[Any]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            retcols: list of strings giving the names of the columns to return
         """
         results = []  # type: List[Dict[str, Any]]
 
@@ -1156,7 +1243,7 @@ class DatabasePool(object):
             it_list[i : i + batch_size] for i in range(0, len(it_list), batch_size)
         ]
         for chunk in chunks:
-            rows = yield self.runInteraction(
+            rows = await self.runInteraction(
                 desc,
                 self.simple_select_many_txn,
                 table,
@@ -1171,19 +1258,27 @@ class DatabasePool(object):
         return results
 
     @classmethod
-    def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+    def simple_select_many_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            retcols: list of strings giving the names of the columns to return
         """
         if not iterable:
             return []
@@ -1204,13 +1299,24 @@ class DatabasePool(object):
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
 
-    def simple_update(self, table, keyvalues, updatevalues, desc):
+    def simple_update(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str,
+    ) -> defer.Deferred:
         return self.runInteraction(
             desc, self.simple_update_txn, table, keyvalues, updatevalues
         )
 
     @staticmethod
-    def simple_update_txn(txn, table, keyvalues, updatevalues):
+    def simple_update_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> int:
         if keyvalues:
             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
         else:
@@ -1227,31 +1333,32 @@ class DatabasePool(object):
         return txn.rowcount
 
     def simple_update_one(
-        self, table, keyvalues, updatevalues, desc="simple_update_one"
-    ):
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str = "simple_update_one",
+    ) -> defer.Deferred:
         """Executes an UPDATE query on the named table, setting new values for
         columns in a row matching the key values.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            updatevalues : dict giving column names and values to update
-            retcols : optional list of column names to return
-
-        If present, retcols gives a list of column names on which to perform
-        a SELECT statement *before* performing the UPDATE statement. The values
-        of these will be returned in a dict.
-
-        These are performed within the same transaction, allowing an atomic
-        get-and-set.  This can be used to implement compare-and-set by putting
-        the update column in the 'keyvalues' dict as well.
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            updatevalues: dict giving column names and values to update
         """
         return self.runInteraction(
             desc, self.simple_update_one_txn, table, keyvalues, updatevalues
         )
 
     @classmethod
-    def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+    def simple_update_one_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> None:
         rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
 
         if rowcount == 0:
@@ -1259,8 +1366,18 @@ class DatabasePool(object):
         if rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
+    # Ideally we could use the overload decorator here to specify that the
+    # return type is only optional if allow_none is True, but this does not work
+    # when you call a static method from an instance.
+    # See https://github.com/python/mypy/issues/7781
     @staticmethod
-    def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+    def simple_select_one_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+    ) -> Optional[Dict[str, Any]]:
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
@@ -1279,24 +1396,28 @@ class DatabasePool(object):
 
         return dict(zip(retcols, row))
 
-    def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+    def simple_delete_one(
+        self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+    ) -> defer.Deferred:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
         """
         return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_one_txn(txn, table, keyvalues):
+    def simple_delete_one_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> None:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
         """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
@@ -1309,11 +1430,13 @@ class DatabasePool(object):
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
-    def simple_delete(self, table, keyvalues, desc):
+    def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str):
         return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_txn(txn, table, keyvalues):
+    def simple_delete_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> int:
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1322,26 +1445,39 @@ class DatabasePool(object):
         txn.execute(sql, list(keyvalues.values()))
         return txn.rowcount
 
-    def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+    def simple_delete_many(
+        self,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        desc: str,
+    ) -> defer.Deferred:
         return self.runInteraction(
             desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
         )
 
     @staticmethod
-    def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+    def simple_delete_many_txn(
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+    ) -> int:
         """Executes a DELETE query on the named table.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
 
         Returns:
-            int: Number rows deleted
+            Number rows deleted
         """
         if not iterable:
             return 0
@@ -1362,8 +1498,14 @@ class DatabasePool(object):
         return txn.rowcount
 
     def get_cache_dict(
-        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
-    ):
+        self,
+        db_conn: Connection,
+        table: str,
+        entity_column: str,
+        stream_column: str,
+        max_value: int,
+        limit: int = 100000,
+    ) -> Tuple[Dict[Any, int], int]:
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
         # It doesn't really matter how many we get, the StreamChangeCache will
         # do the right thing to ensure it respects the max size of cache.
@@ -1396,34 +1538,34 @@ class DatabasePool(object):
 
     def simple_select_list_paginate(
         self,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-        desc="simple_select_list_paginate",
-    ):
+        table: str,
+        orderby: str,
+        start: int,
+        limit: int,
+        retcols: Iterable[str],
+        filters: Optional[Dict[str, Any]] = None,
+        keyvalues: Optional[Dict[str, Any]] = None,
+        order_direction: str = "ASC",
+        desc: str = "simple_select_list_paginate",
+    ) -> defer.Deferred:
         """
         Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
         returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            filters (dict[str, T] | None):
+            table: the table name
+            orderby: Column to order the results by.
+            start: Index to begin the query at.
+            limit: Number of results to return.
+            retcols: the names of the columns to return
+            filters:
                 column names and values to filter the rows with, or None to not
                 apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+            order_direction: Whether the results should be ordered "ASC" or "DESC".
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
@@ -1443,16 +1585,16 @@ class DatabasePool(object):
     @classmethod
     def simple_select_list_paginate_txn(
         cls,
-        txn,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-    ):
+        txn: LoggingTransaction,
+        table: str,
+        orderby: str,
+        start: int,
+        limit: int,
+        retcols: Iterable[str],
+        filters: Optional[Dict[str, Any]] = None,
+        keyvalues: Optional[Dict[str, Any]] = None,
+        order_direction: str = "ASC",
+    ) -> List[Dict[str, Any]]:
         """
         Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
@@ -1463,21 +1605,22 @@ class DatabasePool(object):
         using 'AND'.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            filters (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            orderby: Column to order the results by.
+            start: Index to begin the query at.
+            limit: Number of results to return.
+            retcols: the names of the columns to return
+            filters:
                 column names and values to filter the rows with, or None to not
                 apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+            order_direction: Whether the results should be ordered "ASC" or "DESC".
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            The result as a list of dictionaries.
         """
         if order_direction not in ["ASC", "DESC"]:
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1503,16 +1646,23 @@ class DatabasePool(object):
 
         return cls.cursor_to_dict(txn)
 
-    def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+    def simple_search_list(
+        self,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+        desc="simple_search_list",
+    ):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]] or None
         """
@@ -1522,19 +1672,26 @@ class DatabasePool(object):
         )
 
     @classmethod
-    def simple_search_list_txn(cls, txn, table, term, col, retcols):
+    def simple_search_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+    ) -> Union[List[Dict[str, Any]], int]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            txn: Transaction object
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
+            0 if no term is given, otherwise a list of dictionaries.
         """
         if term:
             sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
@@ -1547,7 +1704,7 @@ class DatabasePool(object):
 
 
 def make_in_list_sql_clause(
-    database_engine, column: str, iterable: Iterable
+    database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
 ) -> Tuple[str, list]:
     """Returns an SQL clause that checks the given column is in the iterable.