summary refs log tree commit diff
path: root/synapse/storage/databases/main/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/roommember.py')
-rw-r--r--synapse/storage/databases/main/roommember.py37
1 files changed, 33 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bef675b845..3248da5356 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+    """Returned by `get_membership_from_event_ids`"""
+
+    user_id: str
+    membership: str
+
+
 class RoomMemberWorkerStore(EventsWorkerStore):
     def __init__(
         self,
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             retcols=("user_id", "display_name", "avatar_url", "event_id"),
             keyvalues={"membership": Membership.JOIN},
             batch_size=500,
-            desc="_get_membership_from_event_ids",
+            desc="_get_joined_profiles_from_event_ids",
         )
 
         return {
@@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return set(room_ids)
 
+    @cached(max_entries=5000)
+    async def _get_membership_from_event_id(
+        self, member_event_id: str
+    ) -> Optional[EventIdMembership]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+    )
     async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
-    ) -> List[dict]:
-        """Get user_id and membership of a set of event IDs."""
+    ) -> Dict[str, Optional[EventIdMembership]]:
+        """Get user_id and membership of a set of event IDs.
+
+        Returns:
+            Mapping from event ID to `EventIdMembership` if the event is a
+            membership event, otherwise the value is None.
+        """
 
-        return await self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
@@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_membership_from_event_ids",
         )
 
+        return {
+            row["event_id"]: EventIdMembership(
+                membership=row["membership"], user_id=row["user_id"]
+            )
+            for row in rows
+        }
+
     async def is_local_host_in_room_ignoring_users(
         self, room_id: str, ignore_users: Collection[str]
     ) -> bool: