summary refs log tree commit diff
path: root/synapse/storage/databases
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases')
-rw-r--r--synapse/storage/databases/main/cache.py4
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/roommember.py37
3 files changed, 44 insertions, 4 deletions
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2d7511d613..dd4e83a2ad 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -192,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
 
+        # The `_get_membership_from_event_id` is immutable, except for the
+        # case where we look up an event *before* persisting it.
+        self._get_membership_from_event_id.invalidate((event_id,))
+
         if not backfilled:
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f60aef180..d253243125 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1745,6 +1745,13 @@ class PersistEventsStore:
                 (event.state_key,),
             )
 
+            # The `_get_membership_from_event_id` is immutable, except for the
+            # case where we look up an event *before* persisting it.
+            txn.call_after(
+                self.store._get_membership_from_event_id.invalidate,
+                (event.event_id,),
+            )
+
             # We update the local_current_membership table only if the event is
             # "current", i.e., its something that has just happened.
             #
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bef675b845..3248da5356 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+    """Returned by `get_membership_from_event_ids`"""
+
+    user_id: str
+    membership: str
+
+
 class RoomMemberWorkerStore(EventsWorkerStore):
     def __init__(
         self,
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             retcols=("user_id", "display_name", "avatar_url", "event_id"),
             keyvalues={"membership": Membership.JOIN},
             batch_size=500,
-            desc="_get_membership_from_event_ids",
+            desc="_get_joined_profiles_from_event_ids",
         )
 
         return {
@@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return set(room_ids)
 
+    @cached(max_entries=5000)
+    async def _get_membership_from_event_id(
+        self, member_event_id: str
+    ) -> Optional[EventIdMembership]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+    )
     async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
-    ) -> List[dict]:
-        """Get user_id and membership of a set of event IDs."""
+    ) -> Dict[str, Optional[EventIdMembership]]:
+        """Get user_id and membership of a set of event IDs.
+
+        Returns:
+            Mapping from event ID to `EventIdMembership` if the event is a
+            membership event, otherwise the value is None.
+        """
 
-        return await self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
@@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_membership_from_event_ids",
         )
 
+        return {
+            row["event_id"]: EventIdMembership(
+                membership=row["membership"], user_id=row["user_id"]
+            )
+            for row in rows
+        }
+
     async def is_local_host_in_room_ignoring_users(
         self, room_id: str, ignore_users: Collection[str]
     ) -> bool: