summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/12423.misc1
-rw-r--r--synapse/handlers/account_validity.py4
-rw-r--r--synapse/storage/databases/main/appservice.py28
-rw-r--r--synapse/storage/databases/main/registration.py133
-rw-r--r--synapse/storage/databases/main/relations.py2
-rw-r--r--synapse/storage/databases/main/signatures.py2
-rw-r--r--synapse/storage/databases/main/state.py2
-rw-r--r--synapse/storage/databases/main/stream.py24
-rw-r--r--synapse/storage/databases/main/tags.py4
9 files changed, 123 insertions, 77 deletions
diff --git a/changelog.d/12423.misc b/changelog.d/12423.misc
new file mode 100644
index 0000000000..e793d08e5e
--- /dev/null
+++ b/changelog.d/12423.misc
@@ -0,0 +1 @@
+Add some type hints to datastore.
\ No newline at end of file
diff --git a/synapse/handlers/account_validity.py b/synapse/handlers/account_validity.py
index 9d0975f636..05a138410e 100644
--- a/synapse/handlers/account_validity.py
+++ b/synapse/handlers/account_validity.py
@@ -180,9 +180,9 @@ class AccountValidityHandler:
         expiring_users = await self.store.get_users_expiring_soon()
 
         if expiring_users:
-            for user in expiring_users:
+            for user_id, expiration_ts_ms in expiring_users:
                 await self._send_renewal_email(
-                    user_id=user["user_id"], expiration_ts=user["expiration_ts_ms"]
+                    user_id=user_id, expiration_ts=expiration_ts_ms
                 )
 
     async def send_renewal_email_to_user(self, user_id: str) -> None:
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index eb32c34a85..fa732edcca 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import TYPE_CHECKING, List, Optional, Pattern, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
 
 from synapse.appservice import (
     ApplicationService,
@@ -26,7 +26,11 @@ from synapse.appservice import (
 from synapse.config.appservice import load_appservices
 from synapse.events import EventBase
 from synapse.storage._base import db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.types import Cursor
@@ -92,7 +96,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
 
         super().__init__(database, db_conn, hs)
 
-    def get_app_services(self):
+    def get_app_services(self) -> List[ApplicationService]:
         return self.services_cache
 
     def get_if_app_services_interested_in_user(self, user_id: str) -> bool:
@@ -256,7 +260,7 @@ class ApplicationServiceTransactionWorkerStore(
             A new transaction.
         """
 
-        def _create_appservice_txn(txn):
+        def _create_appservice_txn(txn: LoggingTransaction) -> AppServiceTransaction:
             new_txn_id = self._as_txn_seq_gen.get_next_id_txn(txn)
 
             # Insert new txn into txn table
@@ -291,7 +295,7 @@ class ApplicationServiceTransactionWorkerStore(
             service: The application service which was sent this transaction.
         """
 
-        def _complete_appservice_txn(txn):
+        def _complete_appservice_txn(txn: LoggingTransaction) -> None:
             # Set current txn_id for AS to 'txn_id'
             self.db_pool.simple_upsert_txn(
                 txn,
@@ -322,7 +326,9 @@ class ApplicationServiceTransactionWorkerStore(
             An AppServiceTransaction or None.
         """
 
-        def _get_oldest_unsent_txn(txn):
+        def _get_oldest_unsent_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Dict[str, Any]]:
             # Monotonically increasing txn ids, so just select the smallest
             # one in the txns table (we delete them when they are sent)
             txn.execute(
@@ -364,7 +370,7 @@ class ApplicationServiceTransactionWorkerStore(
         )
 
     async def set_appservice_last_pos(self, pos: int) -> None:
-        def set_appservice_last_pos_txn(txn):
+        def set_appservice_last_pos_txn(txn: LoggingTransaction) -> None:
             txn.execute(
                 "UPDATE appservice_stream_position SET stream_ordering = ?", (pos,)
             )
@@ -378,7 +384,9 @@ class ApplicationServiceTransactionWorkerStore(
     ) -> Tuple[int, List[EventBase]]:
         """Get all new events for an appservice"""
 
-        def get_new_events_for_appservice_txn(txn):
+        def get_new_events_for_appservice_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, List[str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id"
                 " FROM events AS e"
@@ -416,7 +424,7 @@ class ApplicationServiceTransactionWorkerStore(
                 % (type,)
             )
 
-        def get_type_stream_id_for_appservice_txn(txn):
+        def get_type_stream_id_for_appservice_txn(txn: LoggingTransaction) -> int:
             stream_id_type = "%s_stream_id" % type
             txn.execute(
                 # We do NOT want to escape `stream_id_type`.
@@ -444,7 +452,7 @@ class ApplicationServiceTransactionWorkerStore(
                 % (stream_type,)
             )
 
-        def set_appservice_stream_type_pos_txn(txn):
+        def set_appservice_stream_type_pos_txn(txn: LoggingTransaction) -> None:
             stream_id_type = "%s_stream_id" % stream_type
             txn.execute(
                 "UPDATE application_services_state SET %s = ? WHERE as_id=?"
diff --git a/synapse/storage/databases/main/registration.py b/synapse/storage/databases/main/registration.py
index c7634c92fd..d43163c27c 100644
--- a/synapse/storage/databases/main/registration.py
+++ b/synapse/storage/databases/main/registration.py
@@ -34,7 +34,7 @@ from synapse.storage.databases.main.stats import StatsStore
 from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import IdGenerator
 from synapse.storage.util.sequence import build_sequence_generator
-from synapse.types import UserID, UserInfo
+from synapse.types import JsonDict, UserID, UserInfo
 from synapse.util.caches.descriptors import cached
 
 if TYPE_CHECKING:
@@ -79,7 +79,7 @@ class TokenLookupResult:
 
     # Make the token owner default to the user ID, which is the common case.
     @token_owner.default
-    def _default_token_owner(self):
+    def _default_token_owner(self) -> str:
         return self.user_id
 
 
@@ -299,7 +299,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 the account.
         """
 
-        def set_account_validity_for_user_txn(txn):
+        def set_account_validity_for_user_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_txn(
                 txn=txn,
                 table="account_validity",
@@ -385,23 +385,25 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             desc="get_renewal_token_for_user",
         )
 
-    async def get_users_expiring_soon(self) -> List[Dict[str, Any]]:
+    async def get_users_expiring_soon(self) -> List[Tuple[str, int]]:
         """Selects users whose account will expire in the [now, now + renew_at] time
         window (see configuration for account_validity for information on what renew_at
         refers to).
 
         Returns:
-            A list of dictionaries, each with a user ID and expiration time (in milliseconds).
+            A list of tuples, each with a user ID and expiration time (in milliseconds).
         """
 
-        def select_users_txn(txn, now_ms, renew_at):
+        def select_users_txn(
+            txn: LoggingTransaction, now_ms: int, renew_at: int
+        ) -> List[Tuple[str, int]]:
             sql = (
                 "SELECT user_id, expiration_ts_ms FROM account_validity"
                 " WHERE email_sent = ? AND (expiration_ts_ms - ?) <= ?"
             )
             values = [False, now_ms, renew_at]
             txn.execute(sql, values)
-            return self.db_pool.cursor_to_dict(txn)
+            return cast(List[Tuple[str, int]], txn.fetchall())
 
         return await self.db_pool.runInteraction(
             "get_users_expiring_soon",
@@ -466,7 +468,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             admin: true iff the user is to be a server admin, false otherwise.
         """
 
-        def set_server_admin_txn(txn):
+        def set_server_admin_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn, "users", {"name": user.to_string()}, {"admin": 1 if admin else 0}
             )
@@ -515,7 +517,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             user_type: type of the user or None for a user without a type.
         """
 
-        def set_user_type_txn(txn):
+        def set_user_type_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn, "users", {"name": user.to_string()}, {"user_type": user_type}
             )
@@ -525,7 +527,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         await self.db_pool.runInteraction("set_user_type", set_user_type_txn)
 
-    def _query_for_auth(self, txn, token: str) -> Optional[TokenLookupResult]:
+    def _query_for_auth(
+        self, txn: LoggingTransaction, token: str
+    ) -> Optional[TokenLookupResult]:
         sql = """
             SELECT users.name as user_id,
                 users.is_guest,
@@ -582,7 +586,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "is_support_user", self.is_support_user_txn, user_id
         )
 
-    def is_real_user_txn(self, txn, user_id):
+    def is_real_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
         res = self.db_pool.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
@@ -592,7 +596,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
         return res is None
 
-    def is_support_user_txn(self, txn, user_id):
+    def is_support_user_txn(self, txn: LoggingTransaction, user_id: str) -> bool:
         res = self.db_pool.simple_select_one_onecol_txn(
             txn=txn,
             table="users",
@@ -609,10 +613,11 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
              A mapping of user_id -> password_hash.
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> Dict[str, str]:
             sql = "SELECT name, password_hash FROM users WHERE lower(name) = lower(?)"
             txn.execute(sql, (user_id,))
-            return dict(txn)
+            result = cast(List[Tuple[str, str]], txn.fetchall())
+            return dict(result)
 
         return await self.db_pool.runInteraction("get_users_by_id_case_insensitive", f)
 
@@ -734,7 +739,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         def _replace_user_external_id_txn(
             txn: LoggingTransaction,
-        ):
+        ) -> None:
             _remove_user_external_ids_txn(txn, user_id)
 
             for auth_provider, external_id in record_external_ids:
@@ -790,10 +795,10 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
         return [(r["auth_provider"], r["external_id"]) for r in res]
 
-    async def count_all_users(self):
+    async def count_all_users(self) -> int:
         """Counts all users registered on the homeserver."""
 
-        def _count_users(txn):
+        def _count_users(txn: LoggingTransaction) -> int:
             txn.execute("SELECT COUNT(*) AS users FROM users")
             rows = self.db_pool.cursor_to_dict(txn)
             if rows:
@@ -810,7 +815,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         who registered on the homeserver in the past 24 hours
         """
 
-        def _count_daily_user_type(txn):
+        def _count_daily_user_type(txn: LoggingTransaction) -> Dict[str, int]:
             yesterday = int(self._clock.time()) - (60 * 60 * 24)
 
             sql = """
@@ -835,23 +840,23 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             "count_daily_user_type", _count_daily_user_type
         )
 
-    async def count_nonbridged_users(self):
-        def _count_users(txn):
+    async def count_nonbridged_users(self) -> int:
+        def _count_users(txn: LoggingTransaction) -> int:
             txn.execute(
                 """
                 SELECT COUNT(*) FROM users
                 WHERE appservice_id IS NULL
             """
             )
-            (count,) = txn.fetchone()
+            (count,) = cast(Tuple[int], txn.fetchone())
             return count
 
         return await self.db_pool.runInteraction("count_users", _count_users)
 
-    async def count_real_users(self):
+    async def count_real_users(self) -> int:
         """Counts all users without a special user_type registered on the homeserver."""
 
-        def _count_users(txn):
+        def _count_users(txn: LoggingTransaction) -> int:
             txn.execute("SELECT COUNT(*) AS users FROM users where user_type is null")
             rows = self.db_pool.cursor_to_dict(txn)
             if rows:
@@ -888,7 +893,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         return user_id
 
     def get_user_id_by_threepid_txn(
-        self, txn, medium: str, address: str
+        self, txn: LoggingTransaction, medium: str, address: str
     ) -> Optional[str]:
         """Returns user id from threepid
 
@@ -925,7 +930,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             {"user_id": user_id, "validated_at": validated_at, "added_at": added_at},
         )
 
-    async def user_get_threepids(self, user_id) -> List[Dict[str, Any]]:
+    async def user_get_threepids(self, user_id: str) -> List[Dict[str, Any]]:
         return await self.db_pool.simple_select_list(
             "user_threepids",
             {"user_id": user_id},
@@ -957,7 +962,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
     async def add_user_bound_threepid(
         self, user_id: str, medium: str, address: str, id_server: str
-    ):
+    ) -> None:
         """The server proxied a bind request to the given identity server on
         behalf of the given user. We need to remember this in case the user
         asks us to unbind the threepid.
@@ -1116,7 +1121,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
 
         assert address or sid
 
-        def get_threepid_validation_session_txn(txn):
+        def get_threepid_validation_session_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Dict[str, Any]]:
             sql = """
                 SELECT address, session_id, medium, client_secret,
                 last_send_attempt, validated_at
@@ -1150,7 +1157,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             session_id: The ID of the session to delete
         """
 
-        def delete_threepid_session_txn(txn):
+        def delete_threepid_session_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="threepid_validation_token",
@@ -1170,7 +1177,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     async def cull_expired_threepid_validation_tokens(self) -> None:
         """Remove threepid validation tokens with expiry dates that have passed"""
 
-        def cull_expired_threepid_validation_tokens_txn(txn, ts):
+        def cull_expired_threepid_validation_tokens_txn(
+            txn: LoggingTransaction, ts: int
+        ) -> None:
             sql = """
             DELETE FROM threepid_validation_token WHERE
             expires < ?
@@ -1184,13 +1193,13 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
         )
 
     @wrap_as_background_process("account_validity_set_expiration_dates")
-    async def _set_expiration_date_when_missing(self):
+    async def _set_expiration_date_when_missing(self) -> None:
         """
         Retrieves the list of registered users that don't have an expiration date, and
         adds an expiration date for each of them.
         """
 
-        def select_users_with_no_expiration_date_txn(txn):
+        def select_users_with_no_expiration_date_txn(txn: LoggingTransaction) -> None:
             """Retrieves the list of registered users with no expiration date from the
             database, filtering out deactivated users.
             """
@@ -1213,7 +1222,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             select_users_with_no_expiration_date_txn,
         )
 
-    def set_expiration_date_for_user_txn(self, txn, user_id, use_delta=False):
+    def set_expiration_date_for_user_txn(
+        self, txn: LoggingTransaction, user_id: str, use_delta: bool = False
+    ) -> None:
         """Sets an expiration date to the account with the given user ID.
 
         Args:
@@ -1344,7 +1355,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             token: The registration token pending use
         """
 
-        def _set_registration_token_pending_txn(txn):
+        def _set_registration_token_pending_txn(txn: LoggingTransaction) -> None:
             pending = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 "registration_tokens",
@@ -1358,7 +1369,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 updatevalues={"pending": pending + 1},
             )
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "set_registration_token_pending", _set_registration_token_pending_txn
         )
 
@@ -1372,7 +1383,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             token: The registration token to be 'used'
         """
 
-        def _use_registration_token_txn(txn):
+        def _use_registration_token_txn(txn: LoggingTransaction) -> None:
             # Normally, res is Optional[Dict[str, Any]].
             # Override type because the return type is only optional if
             # allow_none is True, and we don't want mypy throwing errors
@@ -1398,7 +1409,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
                 },
             )
 
-        return await self.db_pool.runInteraction(
+        await self.db_pool.runInteraction(
             "use_registration_token", _use_registration_token_txn
         )
 
@@ -1416,7 +1427,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             A list of dicts, each containing details of a token.
         """
 
-        def select_registration_tokens_txn(txn, now: int, valid: Optional[bool]):
+        def select_registration_tokens_txn(
+            txn: LoggingTransaction, now: int, valid: Optional[bool]
+        ) -> List[Dict[str, Any]]:
             if valid is None:
                 # Return all tokens regardless of validity
                 txn.execute("SELECT * FROM registration_tokens")
@@ -1523,7 +1536,7 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             Whether the row was inserted or not.
         """
 
-        def _create_registration_token_txn(txn):
+        def _create_registration_token_txn(txn: LoggingTransaction) -> bool:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 "registration_tokens",
@@ -1570,7 +1583,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
             A dict with all info about the token, or None if token doesn't exist.
         """
 
-        def _update_registration_token_txn(txn):
+        def _update_registration_token_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[Dict[str, Any]]:
             try:
                 self.db_pool.simple_update_one_txn(
                     txn,
@@ -1651,7 +1666,9 @@ class RegistrationWorkerStore(CacheInvalidationWorkerStore):
     ) -> Optional[RefreshTokenLookupResult]:
         """Lookup a refresh token with hints about its validity."""
 
-        def _lookup_refresh_token_txn(txn) -> Optional[RefreshTokenLookupResult]:
+        def _lookup_refresh_token_txn(
+            txn: LoggingTransaction,
+        ) -> Optional[RefreshTokenLookupResult]:
             txn.execute(
                 """
                 SELECT
@@ -1807,14 +1824,18 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             unique=False,
         )
 
-    async def _background_update_set_deactivated_flag(self, progress, batch_size):
+    async def _background_update_set_deactivated_flag(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Retrieves a list of all deactivated users and sets the 'deactivated' flag to 1
         for each of them.
         """
 
         last_user = progress.get("user_id", "")
 
-        def _background_update_set_deactivated_flag_txn(txn):
+        def _background_update_set_deactivated_flag_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[bool, int]:
             txn.execute(
                 """
                 SELECT
@@ -1886,7 +1907,9 @@ class RegistrationBackgroundUpdateStore(RegistrationWorkerStore):
             deactivated,
         )
 
-    def set_user_deactivated_status_txn(self, txn, user_id: str, deactivated: bool):
+    def set_user_deactivated_status_txn(
+        self, txn: LoggingTransaction, user_id: str, deactivated: bool
+    ) -> None:
         self.db_pool.simple_update_one_txn(
             txn=txn,
             table="users",
@@ -2005,7 +2028,9 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
         return next_id
 
-    def _set_device_for_access_token_txn(self, txn, token: str, device_id: str) -> str:
+    def _set_device_for_access_token_txn(
+        self, txn: LoggingTransaction, token: str, device_id: str
+    ) -> str:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn, "access_tokens", {"token": token}, "device_id"
         )
@@ -2084,7 +2109,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
 
     def _register_user(
         self,
-        txn,
+        txn: LoggingTransaction,
         user_id: str,
         password_hash: Optional[str],
         was_guest: bool,
@@ -2094,7 +2119,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         admin: bool,
         user_type: Optional[str],
         shadow_banned: bool,
-    ):
+    ) -> None:
         user_id_obj = UserID.from_string(user_id)
 
         now = int(self._clock.time())
@@ -2181,7 +2206,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             pointless. Use flush_user separately.
         """
 
-        def user_set_password_hash_txn(txn):
+        def user_set_password_hash_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn, "users", {"name": user_id}, {"password_hash": password_hash}
             )
@@ -2204,7 +2229,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             StoreError(404) if user not found
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn,
                 table="users",
@@ -2229,7 +2254,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             StoreError(404) if user not found
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_update_one_txn(
                 txn,
                 table="users",
@@ -2259,7 +2284,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
             A tuple of (token, token id, device id) for each of the deleted tokens
         """
 
-        def f(txn):
+        def f(txn: LoggingTransaction) -> List[Tuple[str, int, Optional[str]]]:
             keyvalues = {"user_id": user_id}
             if device_id is not None:
                 keyvalues["device_id"] = device_id
@@ -2301,7 +2326,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         return await self.db_pool.runInteraction("user_delete_access_tokens", f)
 
     async def delete_access_token(self, access_token: str) -> None:
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_one_txn(
                 txn, table="access_tokens", keyvalues={"token": access_token}
             )
@@ -2313,7 +2338,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         await self.db_pool.runInteraction("delete_access_token", f)
 
     async def delete_refresh_token(self, refresh_token: str) -> None:
-        def f(txn):
+        def f(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_one_txn(
                 txn, table="refresh_tokens", keyvalues={"token": refresh_token}
             )
@@ -2353,7 +2378,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
         """
 
         # Insert everything into a transaction in order to run atomically
-        def validate_threepid_session_txn(txn):
+        def validate_threepid_session_txn(txn: LoggingTransaction) -> Optional[str]:
             row = self.db_pool.simple_select_one_txn(
                 txn,
                 table="threepid_validation_session",
@@ -2450,7 +2475,7 @@ class RegistrationStore(StatsStore, RegistrationBackgroundUpdateStore):
                 longer be valid
         """
 
-        def start_or_continue_validation_session_txn(txn):
+        def start_or_continue_validation_session_txn(txn: LoggingTransaction) -> None:
             # Create or update a validation session
             self.db_pool.simple_upsert_txn(
                 txn,
diff --git a/synapse/storage/databases/main/relations.py b/synapse/storage/databases/main/relations.py
index db929ef523..407158ceee 100644
--- a/synapse/storage/databases/main/relations.py
+++ b/synapse/storage/databases/main/relations.py
@@ -742,7 +742,7 @@ class RelationsWorkerStore(SQLBaseStore):
                 %s;
         """
 
-        def _get_if_events_have_relations(txn) -> List[str]:
+        def _get_if_events_have_relations(txn: LoggingTransaction) -> List[str]:
             clauses: List[str] = []
             clause, args = make_in_list_sql_clause(
                 txn.database_engine, "relates_to_id", parent_ids
diff --git a/synapse/storage/databases/main/signatures.py b/synapse/storage/databases/main/signatures.py
index 0518b8b910..95148fd227 100644
--- a/synapse/storage/databases/main/signatures.py
+++ b/synapse/storage/databases/main/signatures.py
@@ -26,7 +26,7 @@ from synapse.util.caches.descriptors import cached, cachedList
 
 class SignatureWorkerStore(EventsWorkerStore):
     @cached()
-    def get_event_reference_hash(self, event_id):
+    def get_event_reference_hash(self, event_id: str) -> Dict[str, Dict[str, bytes]]:
         # This is a dummy function to allow get_event_reference_hashes
         # to use its cache
         raise NotImplementedError()
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 4a461a0abb..ecdc1fdc4c 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -204,7 +204,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             The current state of the room.
         """
 
-        def _get_current_state_ids_txn(txn):
+        def _get_current_state_ids_txn(txn: LoggingTransaction) -> StateMap[str]:
             txn.execute(
                 """SELECT type, state_key, event_id FROM current_state_events
                 WHERE room_id = ?
diff --git a/synapse/storage/databases/main/stream.py b/synapse/storage/databases/main/stream.py
index 82e9ef02d2..6d45a8a9f6 100644
--- a/synapse/storage/databases/main/stream.py
+++ b/synapse/storage/databases/main/stream.py
@@ -36,7 +36,17 @@ what sort order was used:
 """
 
 import logging
-from typing import TYPE_CHECKING, Any, Collection, Dict, List, Optional, Set, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Collection,
+    Dict,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 
 import attr
 from frozendict import frozendict
@@ -732,7 +742,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             A tuple of (stream ordering, topological ordering, event_id)
         """
 
-        def _f(txn):
+        def _f(txn: LoggingTransaction) -> Optional[Tuple[int, int, str]]:
             sql = (
                 "SELECT stream_ordering, topological_ordering, event_id"
                 " FROM events"
@@ -742,7 +752,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
                 " LIMIT 1"
             )
             txn.execute(sql, (room_id, stream_ordering))
-            return txn.fetchone()
+            return cast(Optional[Tuple[int, int, str]], txn.fetchone())
 
         return await self.db_pool.runInteraction(
             "get_room_event_before_stream_ordering", _f
@@ -839,7 +849,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     @staticmethod
     def _set_before_and_after(
         events: List[EventBase], rows: List[_EventDictReturn], topo_order: bool = True
-    ):
+    ) -> None:
         """Inserts ordering information to events' internal metadata from
         the DB rows.
 
@@ -985,7 +995,9 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
             the `current_id`).
         """
 
-        def get_all_new_events_stream_txn(txn):
+        def get_all_new_events_stream_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[int, List[str]]:
             sql = (
                 "SELECT e.stream_ordering, e.event_id"
                 " FROM events AS e"
@@ -1331,7 +1343,7 @@ class StreamWorkerStore(EventsWorkerStore, SQLBaseStore):
     async def get_id_for_instance(self, instance_name: str) -> int:
         """Get a unique, immutable ID that corresponds to the given Synapse worker instance."""
 
-        def _get_id_for_instance_txn(txn):
+        def _get_id_for_instance_txn(txn: LoggingTransaction) -> int:
             instance_id = self.db_pool.simple_select_one_onecol_txn(
                 txn,
                 table="instance_map",
diff --git a/synapse/storage/databases/main/tags.py b/synapse/storage/databases/main/tags.py
index c8e508a910..b0f5de67a3 100644
--- a/synapse/storage/databases/main/tags.py
+++ b/synapse/storage/databases/main/tags.py
@@ -97,7 +97,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         )
 
         def get_tag_content(
-            txn: LoggingTransaction, tag_ids
+            txn: LoggingTransaction, tag_ids: List[Tuple[int, str, str]]
         ) -> List[Tuple[int, Tuple[str, str, str]]]:
             sql = "SELECT tag, content FROM room_tags WHERE user_id=? AND room_id=?"
             results = []
@@ -251,7 +251,7 @@ class TagsWorkerStore(AccountDataWorkerStore):
         return self._account_data_id_gen.get_current_token()
 
     def _update_revision_txn(
-        self, txn, user_id: str, room_id: str, next_id: int
+        self, txn: LoggingTransaction, user_id: str, room_id: str, next_id: int
     ) -> None:
         """Update the latest revision of the tags for the given user and room.