summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py30
-rw-r--r--synapse/storage/databases/main/cache.py4
-rw-r--r--synapse/storage/databases/main/events.py7
-rw-r--r--synapse/storage/databases/main/roommember.py37
-rw-r--r--synapse/storage/persist_events.py15
5 files changed, 71 insertions, 22 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 030898e4d0..a402a3e403 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
+from synapse.storage.databases.main.roommember import EventIdMembership
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.descriptors import lru_cache
@@ -292,7 +293,7 @@ def _condition_checker(
     return True
 
 
-MemberMap = Dict[str, Tuple[str, str]]
+MemberMap = Dict[str, Optional[EventIdMembership]]
 Rule = Dict[str, dict]
 RulesByUser = Dict[str, List[Rule]]
 StateGroup = Union[object, int]
@@ -306,7 +307,7 @@ class RulesForRoomData:
     *only* include data, and not references to e.g. the data stores.
     """
 
-    # event_id -> (user_id, state)
+    # event_id -> EventIdMembership
     member_map: MemberMap = attr.Factory(dict)
     # user_id -> rules
     rules_by_user: RulesByUser = attr.Factory(dict)
@@ -447,11 +448,10 @@ class RulesForRoom:
 
                 res = self.data.member_map.get(event_id, None)
                 if res:
-                    user_id, state = res
-                    if state == Membership.JOIN:
-                        rules = self.data.rules_by_user.get(user_id, None)
+                    if res.membership == Membership.JOIN:
+                        rules = self.data.rules_by_user.get(res.user_id, None)
                         if rules:
-                            ret_rules_by_user[user_id] = rules
+                            ret_rules_by_user[res.user_id] = rules
                     continue
 
                 # If a user has left a room we remove their push rule. If they
@@ -502,24 +502,26 @@ class RulesForRoom:
         """
         sequence = self.data.sequence
 
-        rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
-
-        members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
+        members = await self.store.get_membership_from_event_ids(
+            member_event_ids.values()
+        )
 
-        # If the event is a join event then it will be in current state evnts
+        # If the event is a join event then it will be in current state events
         # map but not in the DB, so we have to explicitly insert it.
         if event.type == EventTypes.Member:
             for event_id in member_event_ids.values():
                 if event_id == event.event_id:
-                    members[event_id] = (event.state_key, event.membership)
+                    members[event_id] = EventIdMembership(
+                        user_id=event.state_key, membership=event.membership
+                    )
 
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug("Found members %r: %r", self.room_id, members.values())
 
         joined_user_ids = {
-            user_id
-            for user_id, membership in members.values()
-            if membership == Membership.JOIN
+            entry.user_id
+            for entry in members.values()
+            if entry and entry.membership == Membership.JOIN
         }
 
         logger.debug("Joined: %r", joined_user_ids)
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 2d7511d613..dd4e83a2ad 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -192,6 +192,10 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
 
         self.get_unread_event_push_actions_by_room_for_user.invalidate((room_id,))
 
+        # The `_get_membership_from_event_id` is immutable, except for the
+        # case where we look up an event *before* persisting it.
+        self._get_membership_from_event_id.invalidate((event_id,))
+
         if not backfilled:
             self._events_stream_cache.entity_has_changed(room_id, stream_ordering)
 
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 1f60aef180..d253243125 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1745,6 +1745,13 @@ class PersistEventsStore:
                 (event.state_key,),
             )
 
+            # The `_get_membership_from_event_id` is immutable, except for the
+            # case where we look up an event *before* persisting it.
+            txn.call_after(
+                self.store._get_membership_from_event_id.invalidate,
+                (event.event_id,),
+            )
+
             # We update the local_current_membership table only if the event is
             # "current", i.e., its something that has just happened.
             #
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bef675b845..3248da5356 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -63,6 +63,14 @@ _MEMBERSHIP_PROFILE_UPDATE_NAME = "room_membership_profile_update"
 _CURRENT_STATE_MEMBERSHIP_UPDATE_NAME = "current_state_events_membership"
 
 
+@attr.s(frozen=True, slots=True, auto_attribs=True)
+class EventIdMembership:
+    """Returned by `get_membership_from_event_ids`"""
+
+    user_id: str
+    membership: str
+
+
 class RoomMemberWorkerStore(EventsWorkerStore):
     def __init__(
         self,
@@ -772,7 +780,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             retcols=("user_id", "display_name", "avatar_url", "event_id"),
             keyvalues={"membership": Membership.JOIN},
             batch_size=500,
-            desc="_get_membership_from_event_ids",
+            desc="_get_joined_profiles_from_event_ids",
         )
 
         return {
@@ -1000,12 +1008,26 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return set(room_ids)
 
+    @cached(max_entries=5000)
+    async def _get_membership_from_event_id(
+        self, member_event_id: str
+    ) -> Optional[EventIdMembership]:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_membership_from_event_id", list_name="member_event_ids"
+    )
     async def get_membership_from_event_ids(
         self, member_event_ids: Iterable[str]
-    ) -> List[dict]:
-        """Get user_id and membership of a set of event IDs."""
+    ) -> Dict[str, Optional[EventIdMembership]]:
+        """Get user_id and membership of a set of event IDs.
+
+        Returns:
+            Mapping from event ID to `EventIdMembership` if the event is a
+            membership event, otherwise the value is None.
+        """
 
-        return await self.db_pool.simple_select_many_batch(
+        rows = await self.db_pool.simple_select_many_batch(
             table="room_memberships",
             column="event_id",
             iterable=member_event_ids,
@@ -1015,6 +1037,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             desc="get_membership_from_event_ids",
         )
 
+        return {
+            row["event_id"]: EventIdMembership(
+                membership=row["membership"], user_id=row["user_id"]
+            )
+            for row in rows
+        }
+
     async def is_local_host_in_room_ignoring_users(
         self, room_id: str, ignore_users: Collection[str]
     ) -> bool:
diff --git a/synapse/storage/persist_events.py b/synapse/storage/persist_events.py
index 7d543fdbe0..b402922817 100644
--- a/synapse/storage/persist_events.py
+++ b/synapse/storage/persist_events.py
@@ -1023,8 +1023,13 @@ class EventsPersistenceStorage:
 
         # Check if any of the changes that we don't have events for are joins.
         if events_to_check:
-            rows = await self.main_store.get_membership_from_event_ids(events_to_check)
-            is_still_joined = any(row["membership"] == Membership.JOIN for row in rows)
+            members = await self.main_store.get_membership_from_event_ids(
+                events_to_check
+            )
+            is_still_joined = any(
+                member and member.membership == Membership.JOIN
+                for member in members.values()
+            )
             if is_still_joined:
                 return True
 
@@ -1060,9 +1065,11 @@ class EventsPersistenceStorage:
             ), event_id in current_state.items()
             if typ == EventTypes.Member and not self.is_mine_id(state_key)
         ]
-        rows = await self.main_store.get_membership_from_event_ids(remote_event_ids)
+        members = await self.main_store.get_membership_from_event_ids(remote_event_ids)
         potentially_left_users.update(
-            row["user_id"] for row in rows if row["membership"] == Membership.JOIN
+            member.user_id
+            for member in members.values()
+            if member and member.membership == Membership.JOIN
         )
 
         return False