summary refs log tree commit diff
path: root/synapse/storage/databases/main/state.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/state.py')
-rw-r--r--synapse/storage/databases/main/state.py24
1 files changed, 15 insertions, 9 deletions
diff --git a/synapse/storage/databases/main/state.py b/synapse/storage/databases/main/state.py
index 417aef1dbc..28460fd364 100644
--- a/synapse/storage/databases/main/state.py
+++ b/synapse/storage/databases/main/state.py
@@ -14,7 +14,7 @@
 # limitations under the License.
 import collections.abc
 import logging
-from typing import TYPE_CHECKING, Iterable, Optional, Set
+from typing import TYPE_CHECKING, Collection, Iterable, Optional, Set, Tuple
 
 from synapse.api.constants import EventTypes, Membership
 from synapse.api.errors import NotFoundError, UnsupportedRoomVersionError
@@ -29,7 +29,7 @@ from synapse.storage.database import (
 from synapse.storage.databases.main.events_worker import EventsWorkerStore
 from synapse.storage.databases.main.roommember import RoomMemberWorkerStore
 from synapse.storage.state import StateFilter
-from synapse.types import StateMap
+from synapse.types import JsonDict, StateMap
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import cached, cachedList
 
@@ -241,7 +241,9 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
             # We delegate to the cached version
             return await self.get_current_state_ids(room_id)
 
-        def _get_filtered_current_state_ids_txn(txn):
+        def _get_filtered_current_state_ids_txn(
+            txn: LoggingTransaction,
+        ) -> StateMap[str]:
             results = {}
             sql = """
                 SELECT type, state_key, event_id FROM current_state_events
@@ -281,11 +283,11 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
 
         event_id = state.get((EventTypes.CanonicalAlias, ""))
         if not event_id:
-            return
+            return None
 
         event = await self.get_event(event_id, allow_none=True)
         if not event:
-            return
+            return None
 
         return event.content.get("canonical_alias")
 
@@ -304,7 +306,7 @@ class StateGroupWorkerStore(EventsWorkerStore, SQLBaseStore):
         list_name="event_ids",
         num_args=1,
     )
-    async def _get_state_group_for_events(self, event_ids):
+    async def _get_state_group_for_events(self, event_ids: Collection[str]) -> JsonDict:
         """Returns mapping event_id -> state_group"""
         rows = await self.db_pool.simple_select_many_batch(
             table="event_to_state_groups",
@@ -355,7 +357,7 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
     ):
         super().__init__(database, db_conn, hs)
 
-        self.server_name = hs.hostname
+        self.server_name: str = hs.hostname
 
         self.db_pool.updates.register_background_index_update(
             self.CURRENT_STATE_INDEX_UPDATE_NAME,
@@ -375,7 +377,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
             self._background_remove_left_rooms,
         )
 
-    async def _background_remove_left_rooms(self, progress, batch_size):
+    async def _background_remove_left_rooms(
+        self, progress: JsonDict, batch_size: int
+    ) -> int:
         """Background update to delete rows from `current_state_events` and
         `event_forward_extremities` tables of rooms that the server is no
         longer joined to.
@@ -383,7 +387,9 @@ class MainStateBackgroundUpdateStore(RoomMemberWorkerStore):
 
         last_room_id = progress.get("last_room_id", "")
 
-        def _background_remove_left_rooms_txn(txn):
+        def _background_remove_left_rooms_txn(
+            txn: LoggingTransaction,
+        ) -> Tuple[bool, Set[str]]:
             # get a batch of room ids to consider
             sql = """
                 SELECT DISTINCT room_id FROM current_state_events