summary refs log tree commit diff
path: root/synapse/push
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2022-03-25 14:58:56 +0000
committerGitHub <noreply@github.com>2022-03-25 14:58:56 +0000
commit7ca8ee67a5165e33f03454218c81be96397e7591 (patch)
treee847a8e53e76b7dfc330e66f28d3ce6e6a3f73e7 /synapse/push
parentEnhance logging for inbound federation events (#12301) (diff)
downloadsynapse-7ca8ee67a5165e33f03454218c81be96397e7591.tar.xz
Add cache for `get_membership_from_event_ids` (#12272)
This should speed up push rule calculations for rooms with large numbers of local users when the main push rule cache fails.

Co-authored-by: reivilibre <oliverw@matrix.org>
Diffstat (limited to 'synapse/push')
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py30
1 files changed, 16 insertions, 14 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 030898e4d0..a402a3e403 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -24,6 +24,7 @@ from synapse.event_auth import get_user_power_level
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
+from synapse.storage.databases.main.roommember import EventIdMembership
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import CacheMetric, register_cache
 from synapse.util.caches.descriptors import lru_cache
@@ -292,7 +293,7 @@ def _condition_checker(
     return True
 
 
-MemberMap = Dict[str, Tuple[str, str]]
+MemberMap = Dict[str, Optional[EventIdMembership]]
 Rule = Dict[str, dict]
 RulesByUser = Dict[str, List[Rule]]
 StateGroup = Union[object, int]
@@ -306,7 +307,7 @@ class RulesForRoomData:
     *only* include data, and not references to e.g. the data stores.
     """
 
-    # event_id -> (user_id, state)
+    # event_id -> EventIdMembership
     member_map: MemberMap = attr.Factory(dict)
     # user_id -> rules
     rules_by_user: RulesByUser = attr.Factory(dict)
@@ -447,11 +448,10 @@ class RulesForRoom:
 
                 res = self.data.member_map.get(event_id, None)
                 if res:
-                    user_id, state = res
-                    if state == Membership.JOIN:
-                        rules = self.data.rules_by_user.get(user_id, None)
+                    if res.membership == Membership.JOIN:
+                        rules = self.data.rules_by_user.get(res.user_id, None)
                         if rules:
-                            ret_rules_by_user[user_id] = rules
+                            ret_rules_by_user[res.user_id] = rules
                     continue
 
                 # If a user has left a room we remove their push rule. If they
@@ -502,24 +502,26 @@ class RulesForRoom:
         """
         sequence = self.data.sequence
 
-        rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
-
-        members = {row["event_id"]: (row["user_id"], row["membership"]) for row in rows}
+        members = await self.store.get_membership_from_event_ids(
+            member_event_ids.values()
+        )
 
-        # If the event is a join event then it will be in current state evnts
+        # If the event is a join event then it will be in current state events
         # map but not in the DB, so we have to explicitly insert it.
         if event.type == EventTypes.Member:
             for event_id in member_event_ids.values():
                 if event_id == event.event_id:
-                    members[event_id] = (event.state_key, event.membership)
+                    members[event_id] = EventIdMembership(
+                        user_id=event.state_key, membership=event.membership
+                    )
 
         if logger.isEnabledFor(logging.DEBUG):
             logger.debug("Found members %r: %r", self.room_id, members.values())
 
         joined_user_ids = {
-            user_id
-            for user_id, membership in members.values()
-            if membership == Membership.JOIN
+            entry.user_id
+            for entry in members.values()
+            if entry and entry.membership == Membership.JOIN
         }
 
         logger.debug("Joined: %r", joined_user_ids)