summary refs log tree commit diff
diff options
context:
space:
mode:
authorDirk Klimpel <5740567+dklimpel@users.noreply.github.com>2022-04-27 14:05:00 +0200
committerGitHub <noreply@github.com>2022-04-27 13:05:00 +0100
commitb76f1a4d5f918def1f643910939b80e9e035e07f (patch)
treeb0492ace0e54340b0b40d990a298c2c249274427
parentBound ephemeral events by key (#12544) (diff)
downloadsynapse-b76f1a4d5f918def1f643910939b80e9e035e07f.tar.xz
Add some type hints to datastore (#12485)
-rw-r--r--changelog.d/12485.misc1
-rw-r--r--synapse/storage/databases/main/__init__.py21
-rw-r--r--synapse/storage/databases/main/appservice.py4
-rw-r--r--synapse/storage/databases/main/deviceinbox.py79
-rw-r--r--synapse/storage/databases/main/devices.py51
-rw-r--r--synapse/storage/databases/main/group_server.py4
-rw-r--r--synapse/storage/databases/main/keys.py15
-rw-r--r--synapse/storage/databases/main/media_repository.py13
-rw-r--r--synapse/storage/databases/main/presence.py19
-rw-r--r--synapse/storage/databases/main/pusher.py49
-rw-r--r--synapse/storage/databases/main/state.py4
-rw-r--r--synapse/storage/databases/main/ui_auth.py12
12 files changed, 188 insertions, 84 deletions
diff --git a/changelog.d/12485.misc b/changelog.d/12485.misc
new file mode 100644
index 0000000000..e793d08e5e
--- /dev/null
+++ b/changelog.d/12485.misc
@@ -0,0 +1 @@
+Add some type hints to datastore.
\ No newline at end of file
diff --git a/synapse/storage/databases/main/__init__.py b/synapse/storage/databases/main/__init__.py
index 951031af50..5895b89202 100644
--- a/synapse/storage/databases/main/__init__.py
+++ b/synapse/storage/databases/main/__init__.py
@@ -15,12 +15,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple, cast
 
 from synapse.config.homeserver import HomeServerConfig
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.databases.main.stats import UserSortOrder
-from synapse.storage.engines import PostgresEngine
+from synapse.storage.engines import BaseDatabaseEngine, PostgresEngine
+from synapse.storage.types import Cursor
 from synapse.storage.util.id_generators import (
     IdGenerator,
     MultiWriterIdGenerator,
@@ -266,7 +271,9 @@ class DataStore(
             A tuple of a list of mappings from user to information and a count of total users.
         """
 
-        def get_users_paginate_txn(txn):
+        def get_users_paginate_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
             filters = []
             args = [self.hs.config.server.server_name]
 
@@ -301,7 +308,7 @@ class DataStore(
                 """
             sql = "SELECT COUNT(*) as total_users " + sql_base
             txn.execute(sql, args)
-            count = txn.fetchone()[0]
+            count = cast(Tuple[int], txn.fetchone())[0]
 
             sql = f"""
                 SELECT name, user_type, is_guest, admin, deactivated, shadow_banned,
@@ -338,7 +345,9 @@ class DataStore(
         )
 
 
-def check_database_before_upgrade(cur, database_engine, config: HomeServerConfig):
+def check_database_before_upgrade(
+    cur: Cursor, database_engine: BaseDatabaseEngine, config: HomeServerConfig
+) -> None:
     """Called before upgrading an existing database to check that it is broadly sane
     compared with the configuration.
     """
diff --git a/synapse/storage/databases/main/appservice.py b/synapse/storage/databases/main/appservice.py
index fa732edcca..945707b0ec 100644
--- a/synapse/storage/databases/main/appservice.py
+++ b/synapse/storage/databases/main/appservice.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import logging
 import re
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Pattern, Tuple, cast
 
 from synapse.appservice import (
     ApplicationService,
@@ -83,7 +83,7 @@ class ApplicationServiceWorkerStore(RoomMemberWorkerStore):
             txn.execute(
                 "SELECT COALESCE(max(txn_id), 0) FROM application_services_txns"
             )
-            return txn.fetchone()[0]  # type: ignore
+            return cast(Tuple[int], txn.fetchone())[0]
 
         self._as_txn_seq_gen = build_sequence_generator(
             db_conn,
diff --git a/synapse/storage/databases/main/deviceinbox.py b/synapse/storage/databases/main/deviceinbox.py
index b4a1b041b1..599b418383 100644
--- a/synapse/storage/databases/main/deviceinbox.py
+++ b/synapse/storage/databases/main/deviceinbox.py
@@ -14,7 +14,17 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Collection, Dict, List, Optional, Set, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Collection,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    cast,
+)
 
 from synapse.logging import issue9533_logger
 from synapse.logging.opentracing import log_kv, set_tag, trace
@@ -118,7 +128,13 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             prefilled_cache=device_outbox_prefill,
         )
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows):
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[ToDeviceStream.ToDeviceStreamRow],
+    ) -> None:
         if stream_name == ToDeviceStream.NAME:
             # If replication is happening than postgres must be being used.
             assert isinstance(self._device_inbox_id_gen, MultiWriterIdGenerator)
@@ -134,7 +150,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                     )
         return super().process_replication_rows(stream_name, instance_name, token, rows)
 
-    def get_to_device_stream_token(self):
+    def get_to_device_stream_token(self) -> int:
         return self._device_inbox_id_gen.get_current_token()
 
     async def get_messages_for_user_devices(
@@ -301,7 +317,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         if not user_ids_to_query:
             return {}, to_stream_id
 
-        def get_device_messages_txn(txn: LoggingTransaction):
+        def get_device_messages_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[Dict[Tuple[str, str], List[JsonDict]], int]:
             # Build a query to select messages from any of the given devices that
             # are between the given stream id bounds.
 
@@ -428,7 +446,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
                 log_kv({"message": "No changes in cache since last check"})
                 return 0
 
-        def delete_messages_for_device_txn(txn):
+        def delete_messages_for_device_txn(txn: LoggingTransaction) -> int:
             sql = (
                 "DELETE FROM device_inbox"
                 " WHERE user_id = ? AND device_id = ?"
@@ -455,15 +473,14 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
     @trace
     async def get_new_device_msgs_for_remote(
-        self, destination, last_stream_id, current_stream_id, limit
-    ) -> Tuple[List[dict], int]:
+        self, destination: str, last_stream_id: int, current_stream_id: int, limit: int
+    ) -> Tuple[List[JsonDict], int]:
         """
         Args:
-            destination(str): The name of the remote server.
-            last_stream_id(int|long): The last position of the device message stream
+            destination: The name of the remote server.
+            last_stream_id: The last position of the device message stream
                 that the server sent up to.
-            current_stream_id(int|long): The current position of the device
-                message stream.
+            current_stream_id: The current position of the device message stream.
         Returns:
             A list of messages for the device and where in the stream the messages got to.
         """
@@ -485,7 +502,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             return [], last_stream_id
 
         @trace
-        def get_new_messages_for_remote_destination_txn(txn):
+        def get_new_messages_for_remote_destination_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[JsonDict], int]:
             sql = (
                 "SELECT stream_id, messages_json FROM device_federation_outbox"
                 " WHERE destination = ?"
@@ -527,7 +546,7 @@ class DeviceInboxWorkerStore(SQLBaseStore):
             up_to_stream_id: Where to delete messages up to.
         """
 
-        def delete_messages_for_remote_destination_txn(txn):
+        def delete_messages_for_remote_destination_txn(txn: LoggingTransaction) -> None:
             sql = (
                 "DELETE FROM device_federation_outbox"
                 " WHERE destination = ?"
@@ -566,7 +585,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_new_device_messages_txn(txn):
+        def get_all_new_device_messages_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             # We limit like this as we might have multiple rows per stream_id, and
             # we want to make sure we always get all entries for any stream_id
             # we return.
@@ -607,8 +628,8 @@ class DeviceInboxWorkerStore(SQLBaseStore):
     @trace
     async def add_messages_to_device_inbox(
         self,
-        local_messages_by_user_then_device: dict,
-        remote_messages_by_destination: dict,
+        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+        remote_messages_by_destination: Dict[str, JsonDict],
     ) -> int:
         """Used to send messages from this server.
 
@@ -624,7 +645,9 @@ class DeviceInboxWorkerStore(SQLBaseStore):
 
         assert self._can_write_to_device
 
-        def add_messages_txn(txn, now_ms, stream_id):
+        def add_messages_txn(
+            txn: LoggingTransaction, now_ms: int, stream_id: int
+        ) -> None:
             # Add the local messages directly to the local inbox.
             self._add_messages_to_local_device_inbox_txn(
                 txn, stream_id, local_messages_by_user_then_device
@@ -677,11 +700,16 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return self._device_inbox_id_gen.get_current_token()
 
     async def add_messages_from_remote_to_device_inbox(
-        self, origin: str, message_id: str, local_messages_by_user_then_device: dict
+        self,
+        origin: str,
+        message_id: str,
+        local_messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
     ) -> int:
         assert self._can_write_to_device
 
-        def add_messages_txn(txn, now_ms, stream_id):
+        def add_messages_txn(
+            txn: LoggingTransaction, now_ms: int, stream_id: int
+        ) -> None:
             # Check if we've already inserted a matching message_id for that
             # origin. This can happen if the origin doesn't receive our
             # acknowledgement from the first time we received the message.
@@ -727,8 +755,11 @@ class DeviceInboxWorkerStore(SQLBaseStore):
         return stream_id
 
     def _add_messages_to_local_device_inbox_txn(
-        self, txn, stream_id, messages_by_user_then_device
-    ):
+        self,
+        txn: LoggingTransaction,
+        stream_id: int,
+        messages_by_user_then_device: Dict[str, Dict[str, JsonDict]],
+    ) -> None:
         assert self._can_write_to_device
 
         local_by_user_then_device = {}
@@ -840,8 +871,10 @@ class DeviceInboxBackgroundUpdateStore(SQLBaseStore):
             self._remove_dead_devices_from_device_inbox,
         )
 
-    async def _background_drop_index_device_inbox(self, progress, batch_size):
-        def reindex_txn(conn):
+    async def _background_drop_index_device_inbox(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def reindex_txn(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_inbox_stream_id")
             txn.close()
diff --git a/synapse/storage/databases/main/devices.py b/synapse/storage/databases/main/devices.py
index 483dd80406..2df4dd4ed4 100644
--- a/synapse/storage/databases/main/devices.py
+++ b/synapse/storage/databases/main/devices.py
@@ -25,6 +25,7 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    cast,
 )
 
 from synapse.api.errors import Codes, StoreError
@@ -136,7 +137,9 @@ class DeviceWorkerStore(SQLBaseStore):
             Number of devices of this users.
         """
 
-        def count_devices_by_users_txn(txn, user_ids):
+        def count_devices_by_users_txn(
+            txn: LoggingTransaction, user_ids: List[str]
+        ) -> int:
             sql = """
                 SELECT count(*)
                 FROM devices
@@ -149,7 +152,7 @@ class DeviceWorkerStore(SQLBaseStore):
             )
 
             txn.execute(sql + clause, args)
-            return txn.fetchone()[0]
+            return cast(Tuple[int], txn.fetchone())[0]
 
         if not user_ids:
             return 0
@@ -468,7 +471,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         txn.execute(sql, (destination, from_stream_id, now_stream_id, limit))
 
-        return list(txn)
+        return cast(List[Tuple[str, str, int, Optional[str]]], txn.fetchall())
 
     async def _get_device_update_edus_by_remote(
         self,
@@ -549,7 +552,7 @@ class DeviceWorkerStore(SQLBaseStore):
     async def _get_last_device_update_for_remote_user(
         self, destination: str, user_id: str, from_stream_id: int
     ) -> int:
-        def f(txn):
+        def f(txn: LoggingTransaction) -> int:
             prev_sent_id_sql = """
                 SELECT coalesce(max(stream_id), 0) as stream_id
                 FROM device_lists_outbound_last_success
@@ -767,7 +770,7 @@ class DeviceWorkerStore(SQLBaseStore):
         if not user_ids_to_check:
             return set()
 
-        def _get_users_whose_devices_changed_txn(txn):
+        def _get_users_whose_devices_changed_txn(txn: LoggingTransaction) -> Set[str]:
             changes = set()
 
             stream_id_where_clause = "stream_id > ?"
@@ -966,7 +969,9 @@ class DeviceWorkerStore(SQLBaseStore):
     async def mark_remote_user_device_list_as_unsubscribed(self, user_id: str) -> None:
         """Mark that we no longer track device lists for remote user."""
 
-        def _mark_remote_user_device_list_as_unsubscribed_txn(txn):
+        def _mark_remote_user_device_list_as_unsubscribed_txn(
+            txn: LoggingTransaction,
+        ) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 table="device_lists_remote_extremeties",
@@ -1004,7 +1009,7 @@ class DeviceWorkerStore(SQLBaseStore):
         )
 
     def _store_dehydrated_device_txn(
-        self, txn, user_id: str, device_id: str, device_data: str
+        self, txn: LoggingTransaction, user_id: str, device_id: str, device_data: str
     ) -> Optional[str]:
         old_device_id = self.db_pool.simple_select_one_onecol_txn(
             txn,
@@ -1081,7 +1086,7 @@ class DeviceWorkerStore(SQLBaseStore):
         """
         yesterday = self._clock.time_msec() - prune_age
 
-        def _prune_txn(txn):
+        def _prune_txn(txn: LoggingTransaction) -> None:
             # look for (user, destination) pairs which have an update older than
             # the cutoff.
             #
@@ -1204,8 +1209,10 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             "drop_device_lists_outbound_last_success_non_unique_idx",
         )
 
-    async def _drop_device_list_streams_non_unique_indexes(self, progress, batch_size):
-        def f(conn):
+    async def _drop_device_list_streams_non_unique_indexes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
+        def f(conn: LoggingDatabaseConnection) -> None:
             txn = conn.cursor()
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_cache_id")
             txn.execute("DROP INDEX IF EXISTS device_lists_remote_extremeties_id")
@@ -1217,7 +1224,9 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
         )
         return 1
 
-    async def _remove_duplicate_outbound_pokes(self, progress, batch_size):
+    async def _remove_duplicate_outbound_pokes(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         # for some reason, we have accumulated duplicate entries in
         # device_lists_outbound_pokes, which makes prune_outbound_device_list_pokes less
         # efficient.
@@ -1230,7 +1239,7 @@ class DeviceBackgroundUpdateStore(SQLBaseStore):
             {"stream_id": 0, "destination": "", "user_id": "", "device_id": ""},
         )
 
-        def _txn(txn):
+        def _txn(txn: LoggingTransaction) -> int:
             clause, args = make_tuple_comparison_clause(
                 [(x, last_row[x]) for x in KEY_COLS]
             )
@@ -1602,7 +1611,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
 
         context = get_active_span_text_map()
 
-        def add_device_changes_txn(txn, stream_ids):
+        def add_device_changes_txn(
+            txn: LoggingTransaction, stream_ids: List[int]
+        ) -> None:
             self._add_device_change_to_stream_txn(
                 txn,
                 user_id,
@@ -1635,8 +1646,8 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         txn: LoggingTransaction,
         user_id: str,
         device_ids: Collection[str],
-        stream_ids: List[str],
-    ):
+        stream_ids: List[int],
+    ) -> None:
         txn.call_after(
             self._device_list_stream_cache.entity_has_changed,
             user_id,
@@ -1720,7 +1731,7 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         user_id: str,
         device_ids: Iterable[str],
         room_ids: Collection[str],
-        stream_ids: List[str],
+        stream_ids: List[int],
         context: Dict[str, str],
     ) -> None:
         """Record the user in the room has updated their device."""
@@ -1775,7 +1786,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
             LIMIT ?
         """
 
-        def get_uncoverted_outbound_room_pokes_txn(txn):
+        def get_uncoverted_outbound_room_pokes_txn(
+            txn: LoggingTransaction,
+        ) -> List[Tuple[str, str, str, int, Optional[Dict[str, str]]]]:
             txn.execute(sql, (limit,))
 
             return [
@@ -1808,7 +1821,9 @@ class DeviceStore(DeviceWorkerStore, DeviceBackgroundUpdateStore):
         Marks the associated row in `device_lists_changes_in_room` as handled.
         """
 
-        def add_device_list_outbound_pokes_txn(txn, stream_ids: List[int]):
+        def add_device_list_outbound_pokes_txn(
+            txn: LoggingTransaction, stream_ids: List[int]
+        ) -> None:
             if hosts:
                 self._add_device_outbound_poke_to_stream_txn(
                     txn,
diff --git a/synapse/storage/databases/main/group_server.py b/synapse/storage/databases/main/group_server.py
index 0aef121d83..04efad9e9a 100644
--- a/synapse/storage/databases/main/group_server.py
+++ b/synapse/storage/databases/main/group_server.py
@@ -522,7 +522,9 @@ class GroupServerWorkerStore(SQLBaseStore):
             desc="get_joined_groups",
         )
 
-    async def get_all_groups_for_user(self, user_id, now_token) -> List[JsonDict]:
+    async def get_all_groups_for_user(
+        self, user_id: str, now_token: int
+    ) -> List[JsonDict]:
         def _get_all_groups_for_user_txn(txn: LoggingTransaction) -> List[JsonDict]:
             sql = """
                 SELECT group_id, type, membership, u.content
diff --git a/synapse/storage/databases/main/keys.py b/synapse/storage/databases/main/keys.py
index 6990f3ed1d..0a19f607bd 100644
--- a/synapse/storage/databases/main/keys.py
+++ b/synapse/storage/databases/main/keys.py
@@ -15,11 +15,12 @@
 
 import itertools
 import logging
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Any, Dict, Iterable, List, Optional, Tuple
 
 from signedjson.key import decode_verify_key_bytes
 
 from synapse.storage._base import SQLBaseStore
+from synapse.storage.database import LoggingTransaction
 from synapse.storage.keys import FetchKeyResult
 from synapse.storage.types import Cursor
 from synapse.util.caches.descriptors import cached, cachedList
@@ -35,7 +36,9 @@ class KeyStore(SQLBaseStore):
     """Persistence for signature verification keys"""
 
     @cached()
-    def _get_server_verify_key(self, server_name_and_key_id):
+    def _get_server_verify_key(
+        self, server_name_and_key_id: Tuple[str, str]
+    ) -> FetchKeyResult:
         raise NotImplementedError()
 
     @cachedList(
@@ -179,19 +182,21 @@ class KeyStore(SQLBaseStore):
 
     async def get_server_keys_json(
         self, server_keys: Iterable[Tuple[str, Optional[str], Optional[str]]]
-    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[dict]]:
+    ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
         """Retrieve the key json for a list of server_keys and key ids.
         If no keys are found for a given server, key_id and source then
         that server, key_id, and source triplet entry will be an empty list.
         The JSON is returned as a byte array so that it can be efficiently
         used in an HTTP response.
         Args:
-            server_keys (list): List of (server_name, key_id, source) triplets.
+            server_keys: List of (server_name, key_id, source) triplets.
         Returns:
             A mapping from (server_name, key_id, source) triplets to a list of dicts
         """
 
-        def _get_server_keys_json_txn(txn):
+        def _get_server_keys_json_txn(
+            txn: LoggingTransaction,
+        ) -> Dict[Tuple[str, Optional[str], Optional[str]], List[Dict[str, Any]]]:
             results = {}
             for server_name, key_id, from_server in server_keys:
                 keyvalues = {"server_name": server_name}
diff --git a/synapse/storage/databases/main/media_repository.py b/synapse/storage/databases/main/media_repository.py
index 322ed05390..40ac377ca9 100644
--- a/synapse/storage/databases/main/media_repository.py
+++ b/synapse/storage/databases/main/media_repository.py
@@ -388,7 +388,14 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         return await self.db_pool.runInteraction("get_url_cache", get_url_cache_txn)
 
     async def store_url_cache(
-        self, url, response_code, etag, expires_ts, og, media_id, download_ts
+        self,
+        url: str,
+        response_code: int,
+        etag: Optional[str],
+        expires_ts: int,
+        og: Optional[str],
+        media_id: str,
+        download_ts: int,
     ) -> None:
         await self.db_pool.simple_insert(
             "local_media_repository_url_cache",
@@ -441,7 +448,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     async def get_cached_remote_media(
-        self, origin, media_id: str
+        self, origin: str, media_id: str
     ) -> Optional[Dict[str, Any]]:
         return await self.db_pool.simple_select_one(
             "remote_media_cache",
@@ -608,7 +615,7 @@ class MediaRepositoryStore(MediaRepositoryBackgroundUpdateStore):
         )
 
     async def delete_remote_media(self, media_origin: str, media_id: str) -> None:
-        def delete_remote_media_txn(txn):
+        def delete_remote_media_txn(txn: LoggingTransaction) -> None:
             self.db_pool.simple_delete_txn(
                 txn,
                 "remote_media_cache",
diff --git a/synapse/storage/databases/main/presence.py b/synapse/storage/databases/main/presence.py
index d3c4611686..b47c511450 100644
--- a/synapse/storage/databases/main/presence.py
+++ b/synapse/storage/databases/main/presence.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Dict, Iterable, List, Tuple, cast
+from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Tuple, cast
 
 from synapse.api.presence import PresenceState, UserPresenceState
 from synapse.replication.tcp.streams import PresenceStream
@@ -103,7 +103,9 @@ class PresenceStore(PresenceBackgroundUpdateStore):
             prefilled_cache=presence_cache_prefill,
         )
 
-    async def update_presence(self, presence_states) -> Tuple[int, int]:
+    async def update_presence(
+        self, presence_states: List[UserPresenceState]
+    ) -> Tuple[int, int]:
         assert self._can_persist_presence
 
         stream_ordering_manager = self._presence_id_gen.get_next_mult(
@@ -121,7 +123,10 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         return stream_orderings[-1], self._presence_id_gen.get_current_token()
 
     def _update_presence_txn(
-        self, txn: LoggingTransaction, stream_orderings, presence_states
+        self,
+        txn: LoggingTransaction,
+        stream_orderings: List[int],
+        presence_states: List[UserPresenceState],
     ) -> None:
         for stream_id, state in zip(stream_orderings, presence_states):
             txn.call_after(
@@ -405,7 +410,13 @@ class PresenceStore(PresenceBackgroundUpdateStore):
         self._presence_on_startup = []
         return active_on_startup
 
-    def process_replication_rows(self, stream_name, instance_name, token, rows) -> None:
+    def process_replication_rows(
+        self,
+        stream_name: str,
+        instance_name: str,
+        token: int,
+        rows: Iterable[Any],
+    ) -> None:
         if stream_name == PresenceStream.NAME:
             self._presence_id_gen.advance(instance_name, token)
             for row in rows:
diff --git a/synapse/storage/databases/main/pusher.py b/synapse/storage/databases/main/pusher.py
index cf64cd63a4..91286c9b65 100644
--- a/synapse/storage/databases/main/pusher.py
+++ b/synapse/storage/databases/main/pusher.py
@@ -14,11 +14,25 @@
 # limitations under the License.
 
 import logging
-from typing import TYPE_CHECKING, Any, Dict, Iterable, Iterator, List, Optional, Tuple
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterable,
+    Iterator,
+    List,
+    Optional,
+    Tuple,
+    cast,
+)
 
 from synapse.push import PusherConfig, ThrottleParams
 from synapse.storage._base import SQLBaseStore, db_to_json
-from synapse.storage.database import DatabasePool, LoggingDatabaseConnection
+from synapse.storage.database import (
+    DatabasePool,
+    LoggingDatabaseConnection,
+    LoggingTransaction,
+)
 from synapse.storage.util.id_generators import StreamIdGenerator
 from synapse.types import JsonDict
 from synapse.util import json_encoder
@@ -117,7 +131,7 @@ class PusherWorkerStore(SQLBaseStore):
         return self._decode_pushers_rows(ret)
 
     async def get_all_pushers(self) -> Iterator[PusherConfig]:
-        def get_pushers(txn):
+        def get_pushers(txn: LoggingTransaction) -> Iterator[PusherConfig]:
             txn.execute("SELECT * FROM pushers")
             rows = self.db_pool.cursor_to_dict(txn)
 
@@ -152,7 +166,9 @@ class PusherWorkerStore(SQLBaseStore):
         if last_id == current_id:
             return [], current_id, False
 
-        def get_all_updated_pushers_rows_txn(txn):
+        def get_all_updated_pushers_rows_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[List[Tuple[int, tuple]], int, bool]:
             sql = """
                 SELECT id, user_name, app_id, pushkey
                 FROM pushers
@@ -160,10 +176,13 @@ class PusherWorkerStore(SQLBaseStore):
                 ORDER BY id ASC LIMIT ?
             """
             txn.execute(sql, (last_id, current_id, limit))
-            updates = [
-                (stream_id, (user_name, app_id, pushkey, False))
-                for stream_id, user_name, app_id, pushkey in txn
-            ]
+            updates = cast(
+                List[Tuple[int, tuple]],
+                [
+                    (stream_id, (user_name, app_id, pushkey, False))
+                    for stream_id, user_name, app_id, pushkey in txn
+                ],
+            )
 
             sql = """
                 SELECT stream_id, user_id, app_id, pushkey
@@ -192,12 +211,12 @@ class PusherWorkerStore(SQLBaseStore):
         )
 
     @cached(num_args=1, max_entries=15000)
-    async def get_if_user_has_pusher(self, user_id: str):
+    async def get_if_user_has_pusher(self, user_id: str) -> None:
         # This only exists for the cachedList decorator
         raise NotImplementedError()
 
     async def update_pusher_last_stream_ordering(
-        self, app_id, pushkey, user_id, last_stream_ordering
+        self, app_id: str, pushkey: str, user_id: str, last_stream_ordering: int
     ) -> None:
         await self.db_pool.simple_update_one(
             "pushers",
@@ -291,7 +310,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         last_user = progress.get("last_user", "")
 
-        def _delete_pushers(txn) -> int:
+        def _delete_pushers(txn: LoggingTransaction) -> int:
 
             sql = """
                 SELECT name FROM users
@@ -339,7 +358,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         last_pusher = progress.get("last_pusher", 0)
 
-        def _delete_pushers(txn) -> int:
+        def _delete_pushers(txn: LoggingTransaction) -> int:
 
             sql = """
                 SELECT p.id, access_token FROM pushers AS p
@@ -396,7 +415,7 @@ class PusherWorkerStore(SQLBaseStore):
 
         last_pusher = progress.get("last_pusher", 0)
 
-        def _delete_pushers(txn) -> int:
+        def _delete_pushers(txn: LoggingTransaction) -> int:
 
             sql = """
                 SELECT p.id, p.user_name, p.app_id, p.pushkey
@@ -502,7 +521,7 @@ class PusherStore(PusherWorkerStore):
     async def delete_pusher_by_app_id_pushkey_user_id(
         self, app_id: str, pushkey: str, user_id: str
     ) -> None:
-        def delete_pusher_txn(txn, stream_id):
+        def delete_pusher_txn(txn: LoggingTransaction, stream_id: int) -> None:
             self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
@@ -547,7 +566,7 @@ class PusherStore(PusherWorkerStore):
         # account.
         pushers = list(await self.get_pushers_by_user_id(user_id))
 
-        def delete_pushers_txn(txn, stream_ids):
+        def delete_pushers_txn(txn: LoggingTransaction, stream_ids: List[int]) -> None:
             self._invalidate_cache_and_stream(  # type: ignore[attr-defined]
                 txn, self.get_if_user_has_pusher, (user_id,)
             )
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index e653841fe5..5e340bd03d 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -370,10 +370,10 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
     def _update_state_for_partial_state_event_txn(
         self,
-        txn,
+        txn: LoggingTransaction,
         event: EventBase,
         context: EventContext,
-    ):
+    ) -> None:
         # we shouldn't have any outliers here
         assert not event.internal_metadata.is_outlier()
 
diff --git a/synapse/storage/databases/main/ui_auth.py b/synapse/storage/databases/main/ui_auth.py
index 2d339b6008..f38bedbbcd 100644
--- a/synapse/storage/databases/main/ui_auth.py
+++ b/synapse/storage/databases/main/ui_auth.py
@@ -131,7 +131,7 @@ class UIAuthWorkerStore(SQLBaseStore):
         session_id: str,
         stage_type: str,
         result: Union[str, bool, JsonDict],
-    ):
+    ) -> None:
         """
         Mark a session stage as completed.
 
@@ -200,7 +200,9 @@ class UIAuthWorkerStore(SQLBaseStore):
             desc="set_ui_auth_client_dict",
         )
 
-    async def set_ui_auth_session_data(self, session_id: str, key: str, value: Any):
+    async def set_ui_auth_session_data(
+        self, session_id: str, key: str, value: Any
+    ) -> None:
         """
         Store a key-value pair into the sessions data associated with this
         request. This data is stored server-side and cannot be modified by
@@ -223,7 +225,7 @@ class UIAuthWorkerStore(SQLBaseStore):
 
     def _set_ui_auth_session_data_txn(
         self, txn: LoggingTransaction, session_id: str, key: str, value: Any
-    ):
+    ) -> None:
         # Get the current value.
         result = cast(
             Dict[str, Any],
@@ -275,7 +277,7 @@ class UIAuthWorkerStore(SQLBaseStore):
         session_id: str,
         user_agent: str,
         ip: str,
-    ):
+    ) -> None:
         """Add the given user agent / IP to the tracking table"""
         await self.db_pool.simple_upsert(
             table="ui_auth_sessions_ips",
@@ -318,7 +320,7 @@ class UIAuthWorkerStore(SQLBaseStore):
 
     def _delete_old_ui_auth_sessions_txn(
         self, txn: LoggingTransaction, expiration_time: int
-    ):
+    ) -> None:
         # Get the expired sessions.
         sql = "SELECT session_id FROM ui_auth_sessions WHERE creation_time <= ?"
         txn.execute(sql, [expiration_time])