summary refs log tree commit diff
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--changelog.d/8127.misc1
-rw-r--r--synapse/storage/database.py577
-rw-r--r--synapse/storage/databases/main/ui_auth.py11
3 files changed, 367 insertions, 222 deletions
diff --git a/changelog.d/8127.misc b/changelog.d/8127.misc
new file mode 100644
index 0000000000..cb557122aa
--- /dev/null
+++ b/changelog.d/8127.misc
@@ -0,0 +1 @@
+Add type hints to `synapse.storage.database`.
diff --git a/synapse/storage/database.py b/synapse/storage/database.py
index b9aef96b08..bc327e344e 100644
--- a/synapse/storage/database.py
+++ b/synapse/storage/database.py
@@ -28,6 +28,7 @@ from typing import (
     Optional,
     Tuple,
     TypeVar,
+    Union,
 )
 
 from prometheus_client import Histogram
@@ -125,7 +126,7 @@ class LoggingTransaction:
     method.
 
     Args:
-        txn: The database transcation object to wrap.
+        txn: The database transaction object to wrap.
         name: The name of this transactions for logging.
         database_engine
         after_callbacks: A list that callbacks will be appended to
@@ -160,7 +161,7 @@ class LoggingTransaction:
         self.after_callbacks = after_callbacks
         self.exception_callbacks = exception_callbacks
 
-    def call_after(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_after(self, callback: "Callable[..., None]", *args: Any, **kwargs: Any):
         """Call the given callback on the main twisted thread after the
         transaction has finished. Used to invalidate the caches on the
         correct thread.
@@ -171,7 +172,9 @@ class LoggingTransaction:
         assert self.after_callbacks is not None
         self.after_callbacks.append((callback, args, kwargs))
 
-    def call_on_exception(self, callback: "Callable[..., None]", *args, **kwargs):
+    def call_on_exception(
+        self, callback: "Callable[..., None]", *args: Any, **kwargs: Any
+    ):
         # if self.exception_callbacks is None, that means that whatever constructed the
         # LoggingTransaction isn't expecting there to be any callbacks; assert that
         # is not the case.
@@ -195,7 +198,7 @@ class LoggingTransaction:
     def description(self) -> Any:
         return self.txn.description
 
-    def execute_batch(self, sql, args):
+    def execute_batch(self, sql: str, args: Iterable[Iterable[Any]]) -> None:
         if isinstance(self.database_engine, PostgresEngine):
             from psycopg2.extras import execute_batch  # type: ignore
 
@@ -204,17 +207,17 @@ class LoggingTransaction:
             for val in args:
                 self.execute(sql, val)
 
-    def execute(self, sql: str, *args: Any):
+    def execute(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.execute, sql, *args)
 
-    def executemany(self, sql: str, *args: Any):
+    def executemany(self, sql: str, *args: Any) -> None:
         self._do_execute(self.txn.executemany, sql, *args)
 
     def _make_sql_one_line(self, sql: str) -> str:
         "Strip newlines out of SQL so that the loggers in the DB are on one line"
         return " ".join(line.strip() for line in sql.splitlines() if line.strip())
 
-    def _do_execute(self, func, sql, *args):
+    def _do_execute(self, func, sql: str, *args: Any) -> None:
         sql = self._make_sql_one_line(sql)
 
         # TODO(paul): Maybe use 'info' and 'debug' for values?
@@ -240,7 +243,7 @@ class LoggingTransaction:
             sql_logger.debug("[SQL time] {%s} %f sec", self.name, secs)
             sql_query_timer.labels(sql.split()[0]).observe(secs)
 
-    def close(self):
+    def close(self) -> None:
         self.txn.close()
 
 
@@ -249,13 +252,13 @@ class PerformanceCounters(object):
         self.current_counters = {}
         self.previous_counters = {}
 
-    def update(self, key, duration_secs):
+    def update(self, key: str, duration_secs: float) -> None:
         count, cum_time = self.current_counters.get(key, (0, 0))
         count += 1
         cum_time += duration_secs
         self.current_counters[key] = (count, cum_time)
 
-    def interval(self, interval_duration_secs, limit=3):
+    def interval(self, interval_duration_secs: float, limit: int = 3) -> str:
         counters = []
         for name, (count, cum_time) in self.current_counters.items():
             prev_count, prev_time = self.previous_counters.get(name, (0, 0))
@@ -279,6 +282,9 @@ class PerformanceCounters(object):
         return top_n_counters
 
 
+R = TypeVar("R")
+
+
 class DatabasePool(object):
     """Wraps a single physical database and connection pool.
 
@@ -327,12 +333,12 @@ class DatabasePool(object):
                 self._check_safe_to_upsert,
             )
 
-    def is_running(self):
+    def is_running(self) -> bool:
         """Is the database pool currently running
         """
         return self._db_pool.running
 
-    async def _check_safe_to_upsert(self):
+    async def _check_safe_to_upsert(self) -> None:
         """
         Is it safe to use native UPSERT?
 
@@ -363,7 +369,7 @@ class DatabasePool(object):
                 self._check_safe_to_upsert,
             )
 
-    def start_profiling(self):
+    def start_profiling(self) -> None:
         self._previous_loop_ts = monotonic_time()
 
         def loop():
@@ -387,8 +393,15 @@ class DatabasePool(object):
         self._clock.looping_call(loop, 10000)
 
     def new_transaction(
-        self, conn, desc, after_callbacks, exception_callbacks, func, *args, **kwargs
-    ):
+        self,
+        conn: Connection,
+        desc: str,
+        after_callbacks: List[_CallbackListEntry],
+        exception_callbacks: List[_CallbackListEntry],
+        func: "Callable[..., R]",
+        *args: Any,
+        **kwargs: Any
+    ) -> R:
         start = monotonic_time()
         txn_id = self._TXN_ID
 
@@ -537,7 +550,9 @@ class DatabasePool(object):
 
         return result
 
-    async def runWithConnection(self, func: Callable, *args: Any, **kwargs: Any) -> Any:
+    async def runWithConnection(
+        self, func: "Callable[..., R]", *args: Any, **kwargs: Any
+    ) -> R:
         """Wraps the .runWithConnection() method on the underlying db_pool.
 
         Arguments:
@@ -576,11 +591,11 @@ class DatabasePool(object):
         )
 
     @staticmethod
-    def cursor_to_dict(cursor):
+    def cursor_to_dict(cursor: Cursor) -> List[Dict[str, Any]]:
         """Converts a SQL cursor into an list of dicts.
 
         Args:
-            cursor : The DBAPI cursor which has executed a query.
+            cursor: The DBAPI cursor which has executed a query.
         Returns:
             A list of dicts where the key is the column header.
         """
@@ -588,7 +603,7 @@ class DatabasePool(object):
         results = [dict(zip(col_headers, row)) for row in cursor]
         return results
 
-    def execute(self, desc, decoder, query, *args):
+    def execute(self, desc: str, decoder: Callable, query: str, *args: Any):
         """Runs a single query for a result set.
 
         Args:
@@ -597,7 +612,7 @@ class DatabasePool(object):
             query - The query string to execute
             *args - Query args.
         Returns:
-            The result of decoder(results)
+            Deferred which results to the result of decoder(results)
         """
 
         def interaction(txn):
@@ -612,20 +627,25 @@ class DatabasePool(object):
     # "Simple" SQL API methods that operate on a single table with no JOINs,
     # no complex WHERE clauses, just a dict of values for columns.
 
-    async def simple_insert(self, table, values, or_ignore=False, desc="simple_insert"):
+    async def simple_insert(
+        self,
+        table: str,
+        values: Dict[str, Any],
+        or_ignore: bool = False,
+        desc: str = "simple_insert",
+    ) -> bool:
         """Executes an INSERT query on the named table.
 
         Args:
-            table : string giving the table name
-            values : dict of new column names and values for them
-            or_ignore : bool stating whether an exception should be raised
+            table: string giving the table name
+            values: dict of new column names and values for them
+            or_ignore: bool stating whether an exception should be raised
                 when a conflicting row already exists. If True, False will be
                 returned by the function instead
-            desc : string giving a description of the transaction
+            desc: string giving a description of the transaction
 
         Returns:
-            bool: Whether the row was inserted or not. Only useful when
-            `or_ignore` is True
+             Whether the row was inserted or not. Only useful when `or_ignore` is True
         """
         try:
             await self.runInteraction(desc, self.simple_insert_txn, table, values)
@@ -638,7 +658,9 @@ class DatabasePool(object):
         return True
 
     @staticmethod
-    def simple_insert_txn(txn, table, values):
+    def simple_insert_txn(
+        txn: LoggingTransaction, table: str, values: Dict[str, Any]
+    ) -> None:
         keys, vals = zip(*values.items())
 
         sql = "INSERT INTO %s (%s) VALUES(%s)" % (
@@ -649,11 +671,15 @@ class DatabasePool(object):
 
         txn.execute(sql, vals)
 
-    def simple_insert_many(self, table, values, desc):
+    def simple_insert_many(
+        self, table: str, values: List[Dict[str, Any]], desc: str
+    ) -> defer.Deferred:
         return self.runInteraction(desc, self.simple_insert_many_txn, table, values)
 
     @staticmethod
-    def simple_insert_many_txn(txn, table, values):
+    def simple_insert_many_txn(
+        txn: LoggingTransaction, table: str, values: List[Dict[str, Any]]
+    ) -> None:
         if not values:
             return
 
@@ -683,13 +709,13 @@ class DatabasePool(object):
 
     async def simple_upsert(
         self,
-        table,
-        keyvalues,
-        values,
-        insertion_values={},
-        desc="simple_upsert",
-        lock=True,
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        desc: str = "simple_upsert",
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
 
         `lock` should generally be set to True (the default), but can be set
@@ -703,16 +729,14 @@ class DatabasePool(object):
           this table.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key columns and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key columns and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            None or bool: Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         attempts = 0
         while True:
@@ -739,29 +763,34 @@ class DatabasePool(object):
                 )
 
     def simple_upsert_txn(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> Optional[bool]:
         """
         Pick the UPSERT method which works best on the platform. Either the
         native one (Pg9.5+, recent SQLites), or fall back to an emulated method.
 
         Args:
             txn: The transaction to use.
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            None or bool: Native upserts always return None. Emulated
-            upserts return True if a new entry was created, False if an existing
-            one was updated.
+            Native upserts always return None. Emulated upserts return True if a
+            new entry was created, False if an existing one was updated.
         """
         if self.engine.can_native_upsert and table not in self._unsafe_to_upsert_tables:
-            return self.simple_upsert_txn_native_upsert(
+            self.simple_upsert_txn_native_upsert(
                 txn, table, keyvalues, values, insertion_values=insertion_values
             )
+            return None
         else:
             return self.simple_upsert_txn_emulated(
                 txn,
@@ -773,18 +802,23 @@ class DatabasePool(object):
             )
 
     def simple_upsert_txn_emulated(
-        self, txn, table, keyvalues, values, insertion_values={}, lock=True
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+        lock: bool = True,
+    ) -> bool:
         """
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-            lock (bool): True to lock the table when doing the upsert.
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
+            lock: True to lock the table when doing the upsert.
         Returns:
-            bool: Return True if a new entry was created, False if an existing
+            Returns True if a new entry was created, False if an existing
             one was updated.
         """
         # We need to lock the table :(, unless we're *really* careful
@@ -842,19 +876,21 @@ class DatabasePool(object):
         return True
 
     def simple_upsert_txn_native_upsert(
-        self, txn, table, keyvalues, values, insertion_values={}
-    ):
+        self,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        values: Dict[str, Any],
+        insertion_values: Dict[str, Any] = {},
+    ) -> None:
         """
         Use the native UPSERT functionality in recent PostgreSQL versions.
 
         Args:
-            table (str): The table to upsert into
-            keyvalues (dict): The unique key tables and their new values
-            values (dict): The nonunique columns and their new values
-            insertion_values (dict): additional key/values to use only when
-                inserting
-        Returns:
-            None
+            table: The table to upsert into
+            keyvalues: The unique key tables and their new values
+            values: The nonunique columns and their new values
+            insertion_values: additional key/values to use only when inserting
         """
         allvalues = {}  # type: Dict[str, Any]
         allvalues.update(keyvalues)
@@ -985,18 +1021,22 @@ class DatabasePool(object):
         return txn.execute_batch(sql, args)
 
     def simple_select_one(
-        self, table, keyvalues, retcols, allow_none=False, desc="simple_select_one"
-    ):
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+        desc: str = "simple_select_one",
+    ) -> defer.Deferred:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning multiple columns from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcols : list of strings giving the names of the columns to return
-
-            allow_none : If true, return None instead of failing if the SELECT
-              statement returns no rows
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcols: list of strings giving the names of the columns to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
         """
         return self.runInteraction(
             desc, self.simple_select_one_txn, table, keyvalues, retcols, allow_none
@@ -1004,19 +1044,22 @@ class DatabasePool(object):
 
     def simple_select_one_onecol(
         self,
-        table,
-        keyvalues,
-        retcol,
-        allow_none=False,
-        desc="simple_select_one_onecol",
-    ):
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: bool = False,
+        desc: str = "simple_select_one_onecol",
+    ) -> defer.Deferred:
         """Executes a SELECT query on the named table, which is expected to
         return a single row, returning a single column from it.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            retcol : string giving the name of the column to return
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            retcol: string giving the name of the column to return
+            allow_none: If true, return None instead of failing if the SELECT
+                statement returns no rows
+            desc: description of the transaction, for logging and metrics
         """
         return self.runInteraction(
             desc,
@@ -1029,8 +1072,13 @@ class DatabasePool(object):
 
     @classmethod
     def simple_select_one_onecol_txn(
-        cls, txn, table, keyvalues, retcol, allow_none=False
-    ):
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+        allow_none: bool = False,
+    ) -> Optional[Any]:
         ret = cls.simple_select_onecol_txn(
             txn, table=table, keyvalues=keyvalues, retcol=retcol
         )
@@ -1044,7 +1092,12 @@ class DatabasePool(object):
                 raise StoreError(404, "No row found")
 
     @staticmethod
-    def simple_select_onecol_txn(txn, table, keyvalues, retcol):
+    def simple_select_onecol_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcol: Iterable[str],
+    ) -> List[Any]:
         sql = ("SELECT %(retcol)s FROM %(table)s") % {"retcol": retcol, "table": table}
 
         if keyvalues:
@@ -1056,15 +1109,19 @@ class DatabasePool(object):
         return [r[0] for r in txn]
 
     def simple_select_onecol(
-        self, table, keyvalues, retcol, desc="simple_select_onecol"
-    ):
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcol: str,
+        desc: str = "simple_select_onecol",
+    ) -> defer.Deferred:
         """Executes a SELECT query on the named table, which returns a list
         comprising of the values of the named column from the selected rows.
 
         Args:
-            table (str): table name
-            keyvalues (dict|None): column names and values to select the rows with
-            retcol (str): column whos value we wish to retrieve.
+            table: table name
+            keyvalues: column names and values to select the rows with
+            retcol: column whos value we wish to retrieve.
 
         Returns:
             Deferred: Results in a list
@@ -1073,16 +1130,22 @@ class DatabasePool(object):
             desc, self.simple_select_onecol_txn, table, keyvalues, retcol
         )
 
-    def simple_select_list(self, table, keyvalues, retcols, desc="simple_select_list"):
+    def simple_select_list(
+        self,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+        desc: str = "simple_select_list",
+    ) -> defer.Deferred:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            keyvalues (dict[str, Any] | None):
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
@@ -1091,17 +1154,23 @@ class DatabasePool(object):
         )
 
     @classmethod
-    def simple_select_list_txn(cls, txn, table, keyvalues, retcols):
+    def simple_select_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Optional[Dict[str, Any]],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            keyvalues (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            retcols (iterable[str]): the names of the columns to return
+            retcols: the names of the columns to return
         """
         if keyvalues:
             sql = "SELECT %s FROM %s WHERE %s" % (
@@ -1118,25 +1187,25 @@ class DatabasePool(object):
 
     async def simple_select_many_batch(
         self,
-        table,
-        column,
-        iterable,
-        retcols,
-        keyvalues={},
-        desc="simple_select_many_batch",
-        batch_size=100,
-    ):
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        retcols: Iterable[str],
+        keyvalues: Dict[str, Any] = {},
+        desc: str = "simple_select_many_batch",
+        batch_size: int = 100,
+    ) -> List[Any]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            retcols: list of strings giving the names of the columns to return
         """
         results = []  # type: List[Dict[str, Any]]
 
@@ -1165,19 +1234,27 @@ class DatabasePool(object):
         return results
 
     @classmethod
-    def simple_select_many_txn(cls, txn, table, column, iterable, keyvalues, retcols):
+    def simple_select_many_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+    ) -> List[Dict[str, Any]]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
-            retcols : list of strings giving the names of the columns to return
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
+            retcols: list of strings giving the names of the columns to return
         """
         if not iterable:
             return []
@@ -1198,13 +1275,24 @@ class DatabasePool(object):
         txn.execute(sql, values)
         return cls.cursor_to_dict(txn)
 
-    def simple_update(self, table, keyvalues, updatevalues, desc):
+    def simple_update(
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str,
+    ) -> defer.Deferred:
         return self.runInteraction(
             desc, self.simple_update_txn, table, keyvalues, updatevalues
         )
 
     @staticmethod
-    def simple_update_txn(txn, table, keyvalues, updatevalues):
+    def simple_update_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> int:
         if keyvalues:
             where = "WHERE %s" % " AND ".join("%s = ?" % k for k in keyvalues.keys())
         else:
@@ -1221,31 +1309,32 @@ class DatabasePool(object):
         return txn.rowcount
 
     def simple_update_one(
-        self, table, keyvalues, updatevalues, desc="simple_update_one"
-    ):
+        self,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+        desc: str = "simple_update_one",
+    ) -> defer.Deferred:
         """Executes an UPDATE query on the named table, setting new values for
         columns in a row matching the key values.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
-            updatevalues : dict giving column names and values to update
-            retcols : optional list of column names to return
-
-        If present, retcols gives a list of column names on which to perform
-        a SELECT statement *before* performing the UPDATE statement. The values
-        of these will be returned in a dict.
-
-        These are performed within the same transaction, allowing an atomic
-        get-and-set.  This can be used to implement compare-and-set by putting
-        the update column in the 'keyvalues' dict as well.
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
+            updatevalues: dict giving column names and values to update
         """
         return self.runInteraction(
             desc, self.simple_update_one_txn, table, keyvalues, updatevalues
         )
 
     @classmethod
-    def simple_update_one_txn(cls, txn, table, keyvalues, updatevalues):
+    def simple_update_one_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        updatevalues: Dict[str, Any],
+    ) -> None:
         rowcount = cls.simple_update_txn(txn, table, keyvalues, updatevalues)
 
         if rowcount == 0:
@@ -1253,8 +1342,18 @@ class DatabasePool(object):
         if rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
+    # Ideally we could use the overload decorator here to specify that the
+    # return type is only optional if allow_none is True, but this does not work
+    # when you call a static method from an instance.
+    # See https://github.com/python/mypy/issues/7781
     @staticmethod
-    def simple_select_one_txn(txn, table, keyvalues, retcols, allow_none=False):
+    def simple_select_one_txn(
+        txn: LoggingTransaction,
+        table: str,
+        keyvalues: Dict[str, Any],
+        retcols: Iterable[str],
+        allow_none: bool = False,
+    ) -> Optional[Dict[str, Any]]:
         select_sql = "SELECT %s FROM %s WHERE %s" % (
             ", ".join(retcols),
             table,
@@ -1273,24 +1372,28 @@ class DatabasePool(object):
 
         return dict(zip(retcols, row))
 
-    def simple_delete_one(self, table, keyvalues, desc="simple_delete_one"):
+    def simple_delete_one(
+        self, table: str, keyvalues: Dict[str, Any], desc: str = "simple_delete_one"
+    ) -> defer.Deferred:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
         """
         return self.runInteraction(desc, self.simple_delete_one_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_one_txn(txn, table, keyvalues):
+    def simple_delete_one_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> None:
         """Executes a DELETE query on the named table, expecting to delete a
         single row.
 
         Args:
-            table : string giving the table name
-            keyvalues : dict of column names and values to select the row with
+            table: string giving the table name
+            keyvalues: dict of column names and values to select the row with
         """
         sql = "DELETE FROM %s WHERE %s" % (
             table,
@@ -1303,11 +1406,13 @@ class DatabasePool(object):
         if txn.rowcount > 1:
             raise StoreError(500, "More than one row matched (%s)" % (table,))
 
-    def simple_delete(self, table, keyvalues, desc):
+    def simple_delete(self, table: str, keyvalues: Dict[str, Any], desc: str):
         return self.runInteraction(desc, self.simple_delete_txn, table, keyvalues)
 
     @staticmethod
-    def simple_delete_txn(txn, table, keyvalues):
+    def simple_delete_txn(
+        txn: LoggingTransaction, table: str, keyvalues: Dict[str, Any]
+    ) -> int:
         sql = "DELETE FROM %s WHERE %s" % (
             table,
             " AND ".join("%s = ?" % (k,) for k in keyvalues),
@@ -1316,26 +1421,39 @@ class DatabasePool(object):
         txn.execute(sql, list(keyvalues.values()))
         return txn.rowcount
 
-    def simple_delete_many(self, table, column, iterable, keyvalues, desc):
+    def simple_delete_many(
+        self,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+        desc: str,
+    ) -> defer.Deferred:
         return self.runInteraction(
             desc, self.simple_delete_many_txn, table, column, iterable, keyvalues
         )
 
     @staticmethod
-    def simple_delete_many_txn(txn, table, column, iterable, keyvalues):
+    def simple_delete_many_txn(
+        txn: LoggingTransaction,
+        table: str,
+        column: str,
+        iterable: Iterable[Any],
+        keyvalues: Dict[str, Any],
+    ) -> int:
         """Executes a DELETE query on the named table.
 
         Filters rows by if value of `column` is in `iterable`.
 
         Args:
-            txn : Transaction object
-            table : string giving the table name
-            column : column name to test for inclusion against `iterable`
-            iterable : list
-            keyvalues : dict of column names and values to select the rows with
+            txn: Transaction object
+            table: string giving the table name
+            column: column name to test for inclusion against `iterable`
+            iterable: list
+            keyvalues: dict of column names and values to select the rows with
 
         Returns:
-            int: Number rows deleted
+            Number rows deleted
         """
         if not iterable:
             return 0
@@ -1356,8 +1474,14 @@ class DatabasePool(object):
         return txn.rowcount
 
     def get_cache_dict(
-        self, db_conn, table, entity_column, stream_column, max_value, limit=100000
-    ):
+        self,
+        db_conn: Connection,
+        table: str,
+        entity_column: str,
+        stream_column: str,
+        max_value: int,
+        limit: int = 100000,
+    ) -> Tuple[Dict[Any, int], int]:
         # Fetch a mapping of room_id -> max stream position for "recent" rooms.
         # It doesn't really matter how many we get, the StreamChangeCache will
         # do the right thing to ensure it respects the max size of cache.
@@ -1390,34 +1514,34 @@ class DatabasePool(object):
 
     def simple_select_list_paginate(
         self,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-        desc="simple_select_list_paginate",
-    ):
+        table: str,
+        orderby: str,
+        start: int,
+        limit: int,
+        retcols: Iterable[str],
+        filters: Optional[Dict[str, Any]] = None,
+        keyvalues: Optional[Dict[str, Any]] = None,
+        order_direction: str = "ASC",
+        desc: str = "simple_select_list_paginate",
+    ) -> defer.Deferred:
         """
         Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
         returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            filters (dict[str, T] | None):
+            table: the table name
+            orderby: Column to order the results by.
+            start: Index to begin the query at.
+            limit: Number of results to return.
+            retcols: the names of the columns to return
+            filters:
                 column names and values to filter the rows with, or None to not
                 apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+            order_direction: Whether the results should be ordered "ASC" or "DESC".
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]]
         """
@@ -1437,16 +1561,16 @@ class DatabasePool(object):
     @classmethod
     def simple_select_list_paginate_txn(
         cls,
-        txn,
-        table,
-        orderby,
-        start,
-        limit,
-        retcols,
-        filters=None,
-        keyvalues=None,
-        order_direction="ASC",
-    ):
+        txn: LoggingTransaction,
+        table: str,
+        orderby: str,
+        start: int,
+        limit: int,
+        retcols: Iterable[str],
+        filters: Optional[Dict[str, Any]] = None,
+        keyvalues: Optional[Dict[str, Any]] = None,
+        order_direction: str = "ASC",
+    ) -> List[Dict[str, Any]]:
         """
         Executes a SELECT query on the named table with start and limit,
         of row numbers, which may return zero or number of rows from start to limit,
@@ -1457,21 +1581,22 @@ class DatabasePool(object):
         using 'AND'.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            orderby (str): Column to order the results by.
-            start (int): Index to begin the query at.
-            limit (int): Number of results to return.
-            retcols (iterable[str]): the names of the columns to return
-            filters (dict[str, T] | None):
+            txn: Transaction object
+            table: the table name
+            orderby: Column to order the results by.
+            start: Index to begin the query at.
+            limit: Number of results to return.
+            retcols: the names of the columns to return
+            filters:
                 column names and values to filter the rows with, or None to not
                 apply a WHERE ? LIKE ? clause.
-            keyvalues (dict[str, T] | None):
+            keyvalues:
                 column names and values to select the rows with, or None to not
                 apply a WHERE clause.
-            order_direction (str): Whether the results should be ordered "ASC" or "DESC".
+            order_direction: Whether the results should be ordered "ASC" or "DESC".
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]]
+            The result as a list of dictionaries.
         """
         if order_direction not in ["ASC", "DESC"]:
             raise ValueError("order_direction must be one of 'ASC' or 'DESC'.")
@@ -1497,16 +1622,23 @@ class DatabasePool(object):
 
         return cls.cursor_to_dict(txn)
 
-    def simple_search_list(self, table, term, col, retcols, desc="simple_search_list"):
+    def simple_search_list(
+        self,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+        desc="simple_search_list",
+    ):
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
             defer.Deferred: resolves to list[dict[str, Any]] or None
         """
@@ -1516,19 +1648,26 @@ class DatabasePool(object):
         )
 
     @classmethod
-    def simple_search_list_txn(cls, txn, table, term, col, retcols):
+    def simple_search_list_txn(
+        cls,
+        txn: LoggingTransaction,
+        table: str,
+        term: Optional[str],
+        col: str,
+        retcols: Iterable[str],
+    ) -> Union[List[Dict[str, Any]], int]:
         """Executes a SELECT query on the named table, which may return zero or
         more rows, returning the result as a list of dicts.
 
         Args:
-            txn : Transaction object
-            table (str): the table name
-            term (str | None):
-                term for searching the table matched to a column.
-            col (str): column to query term should be matched to
-            retcols (iterable[str]): the names of the columns to return
+            txn: Transaction object
+            table: the table name
+            term: term for searching the table matched to a column.
+            col: column to query term should be matched to
+            retcols: the names of the columns to return
+
         Returns:
-            defer.Deferred: resolves to list[dict[str, Any]] or None
+            0 if no term is given, otherwise a list of dictionaries.
         """
         if term:
             sql = "SELECT %s FROM %s WHERE %s LIKE ?" % (", ".join(retcols), table, col)
@@ -1541,7 +1680,7 @@ class DatabasePool(object):
 
 
 def make_in_list_sql_clause(
-    database_engine, column: str, iterable: Iterable
+    database_engine: BaseDatabaseEngine, column: str, iterable: Iterable
 ) -> Tuple[str, list]:
     """Returns an SQL clause that checks the given column is in the iterable.
 
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 37276f73f8..d80d7da895 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -19,6 +19,7 @@ from canonicaljson import json
 
 from synapse.api.errors import StoreError
 from synapse.storage._base import SQLBaseStore, db_to_json
+from synapse.storage.database import LoggingTransaction
 from synapse.types import JsonDict
 from synapse.util import stringutils as stringutils
 
@@ -214,14 +215,16 @@ class UIAuthWorkerStore(SQLBaseStore):
             value,
         )
 
-    def _set_ui_auth_session_data_txn(self, txn, session_id: str, key: str, value: Any):
+    def _set_ui_auth_session_data_txn(
+        self, txn: LoggingTransaction, session_id: str, key: str, value: Any
+    ):
         # Get the current value.
         result = self.db_pool.simple_select_one_txn(
             txn,
             table="ui_auth_sessions",
             keyvalues={"session_id": session_id},
             retcols=("serverdict",),
-        )
+        )  # type: Dict[str, Any]  # type: ignore
 
         # Update it and add it back to the database.
         serverdict = db_to_json(result["serverdict"])
@@ -275,7 +278,9 @@ class UIAuthStore(UIAuthWorkerStore):
             expiration_time,
         )
 
-    def _delete_old_ui_auth_sessions_txn(self, txn, expiration_time: int):
+    def _delete_old_ui_auth_sessions_txn(
+        self, txn: LoggingTransaction, expiration_time: int
+    ):
         # Get the expired sessions.
         sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
         txn.execute(sql, [expiration_time])