summary refs log tree commit diff
path: root/synapse/storage/databases/main/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r--synapse/storage/databases/main/roommember.py78
1 files changed, 43 insertions, 35 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bbe08368db..3a87eba430 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -27,6 +27,7 @@ from typing import (
     Set,
     Tuple,
     Union,
+    cast,
 )
 
 import attr
@@ -683,25 +684,28 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             Map from user_id to set of rooms that is currently in.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="current_state_events",
-            column="state_key",
-            iterable=user_ids,
-            retcols=(
-                "state_key",
-                "room_id",
+        rows = cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="current_state_events",
+                column="state_key",
+                iterable=user_ids,
+                retcols=(
+                    "state_key",
+                    "room_id",
+                ),
+                keyvalues={
+                    "type": EventTypes.Member,
+                    "membership": Membership.JOIN,
+                },
+                desc="get_rooms_for_users",
             ),
-            keyvalues={
-                "type": EventTypes.Member,
-                "membership": Membership.JOIN,
-            },
-            desc="get_rooms_for_users",
         )
 
         user_rooms: Dict[str, Set[str]] = {user_id: set() for user_id in user_ids}
 
-        for row in rows:
-            user_rooms[row["state_key"]].add(row["room_id"])
+        for state_key, room_id in rows:
+            user_rooms[state_key].add(room_id)
 
         return {key: frozenset(rooms) for key, rooms in user_rooms.items()}
 
@@ -892,17 +896,20 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             Map from event ID to `user_id`, or None if event is not a join.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="room_memberships",
-            column="event_id",
-            iterable=event_ids,
-            retcols=("user_id", "event_id"),
-            keyvalues={"membership": Membership.JOIN},
-            batch_size=1000,
-            desc="_get_user_ids_from_membership_event_ids",
+        rows = cast(
+            List[Tuple[str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="room_memberships",
+                column="event_id",
+                iterable=event_ids,
+                retcols=("event_id", "user_id"),
+                keyvalues={"membership": Membership.JOIN},
+                batch_size=1000,
+                desc="_get_user_ids_from_membership_event_ids",
+            ),
         )
 
-        return {row["event_id"]: row["user_id"] for row in rows}
+        return dict(rows)
 
     @cached(max_entries=10000)
     async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1202,21 +1209,22 @@ class RoomMemberWorkerStore(EventsWorkerStore, CacheInvalidationWorkerStore):
             membership event, otherwise the value is None.
         """
 
-        rows = await self.db_pool.simple_select_many_batch(
-            table="room_memberships",
-            column="event_id",
-            iterable=member_event_ids,
-            retcols=("user_id", "membership", "event_id"),
-            keyvalues={},
-            batch_size=500,
-            desc="get_membership_from_event_ids",
+        rows = cast(
+            List[Tuple[str, str, str]],
+            await self.db_pool.simple_select_many_batch(
+                table="room_memberships",
+                column="event_id",
+                iterable=member_event_ids,
+                retcols=("user_id", "membership", "event_id"),
+                keyvalues={},
+                batch_size=500,
+                desc="get_membership_from_event_ids",
+            ),
         )
 
         return {
-            row["event_id"]: EventIdMembership(
-                membership=row["membership"], user_id=row["user_id"]
-            )
-            for row in rows
+            event_id: EventIdMembership(membership=membership, user_id=user_id)
+            for user_id, membership, event_id in rows
         }
 
     async def is_local_host_in_room_ignoring_users(