diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 4cc8a2ecca..90436a043e 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -387,77 +387,27 @@ class RulesForRoom:
self.room_push_rule_cache_metrics.inc_hits()
return self.data.rules_by_user
- self.room_push_rule_cache_metrics.inc_misses()
-
- ret_rules_by_user = {}
- missing_member_event_ids = {}
- if state_group and self.data.state_group == context.prev_group:
- # If we have a simple delta then we can reuse most of the previous
- # results.
- ret_rules_by_user = self.data.rules_by_user
- current_state_ids = context.delta_ids
-
- push_rules_delta_state_cache_metric.inc_hits()
- else:
- current_state_ids = await context.get_current_state_ids()
- push_rules_delta_state_cache_metric.inc_misses()
- # Ensure the state IDs exist.
- assert current_state_ids is not None
-
- push_rules_state_size_counter.inc(len(current_state_ids))
-
- logger.debug(
- "Looking for member changes in %r %r", state_group, current_state_ids
+ local_users = await self.store.get_local_users_in_room(
+ self.room_id, on_invalidate=self.invalidate_all_cb
)
- # Loop through to see which member events we've seen and have rules
- # for and which we need to fetch
- for key in current_state_ids:
- typ, user_id = key
- if typ != EventTypes.Member:
- continue
-
- if user_id in self.data.uninteresting_user_set:
- continue
-
- if not self.is_mine_id(user_id):
- self.data.uninteresting_user_set.add(user_id)
- continue
+ if event.type == EventTypes.Member and event.membership == Membership.JOIN:
+ if self.is_mine_id(event.state_key):
+ local_users = list(local_users)
+ local_users.append(event.state_key)
- if self.store.get_if_app_services_interested_in_user(user_id):
- self.data.uninteresting_user_set.add(user_id)
- continue
+ ret_rules_by_user = await self.store.bulk_get_push_rules(
+ local_users, on_invalidate=self.invalidate_all_cb
+ )
- event_id = current_state_ids[key]
+ logger.info("Users in room: %s", local_users)
- res = self.data.member_map.get(event_id, None)
- if res:
- if res.membership == Membership.JOIN:
- rules = self.data.rules_by_user.get(res.user_id, None)
- if rules:
- ret_rules_by_user[res.user_id] = rules
- continue
-
- # If a user has left a room we remove their push rule. If they
- # joined then we re-add it later in _update_rules_with_member_event_ids
- ret_rules_by_user.pop(user_id, None)
- missing_member_event_ids[user_id] = event_id
-
- if missing_member_event_ids:
- # If we have some member events we haven't seen, look them up
- # and fetch push rules for them if appropriate.
- logger.debug("Found new member events %r", missing_member_event_ids)
- await self._update_rules_with_member_event_ids(
- ret_rules_by_user, missing_member_event_ids, state_group, event
- )
- else:
- # The push rules didn't change but lets update the cache anyway
- self.update_cache(
- self.data.sequence,
- members={}, # There were no membership changes
- rules_by_user=ret_rules_by_user,
- state_group=state_group,
- )
+ self.update_cache(
+ self.data.sequence,
+ members={}, # There were no membership changes
+ rules_by_user=ret_rules_by_user,
+ state_group=state_group,
+ )
if logger.isEnabledFor(logging.DEBUG):
logger.debug(
@@ -465,67 +415,6 @@ class RulesForRoom:
)
return ret_rules_by_user
- async def _update_rules_with_member_event_ids(
- self,
- ret_rules_by_user: Dict[str, list],
- member_event_ids: Dict[str, str],
- state_group: Optional[int],
- event: EventBase,
- ) -> None:
- """Update the partially filled rules_by_user dict by fetching rules for
- any newly joined users in the `member_event_ids` list.
-
- Args:
- ret_rules_by_user: Partially filled dict of push rules. Gets
- updated with any new rules.
- member_event_ids: Dict of user id to event id for membership events
- that have happened since the last time we filled rules_by_user
- state_group: The state group we are currently computing push rules
- for. Used when updating the cache.
- event: The event we are currently computing push rules for.
- """
- sequence = self.data.sequence
-
- members = await self.store.get_membership_from_event_ids(
- member_event_ids.values()
- )
-
- # If the event is a join event then it will be in current state events
- # map but not in the DB, so we have to explicitly insert it.
- if event.type == EventTypes.Member:
- for event_id in member_event_ids.values():
- if event_id == event.event_id:
- members[event_id] = EventIdMembership(
- user_id=event.state_key, membership=event.membership
- )
-
- if logger.isEnabledFor(logging.DEBUG):
- logger.debug("Found members %r: %r", self.room_id, members.values())
-
- joined_user_ids = {
- entry.user_id
- for entry in members.values()
- if entry and entry.membership == Membership.JOIN
- }
-
- logger.debug("Joined: %r", joined_user_ids)
-
- # Previously we only considered users with pushers or read receipts in that
- # room. We can't do this anymore because we use push actions to calculate unread
- # counts, which don't rely on the user having pushers or sent a read receipt into
- # the room. Therefore we just need to filter for local users here.
- user_ids = list(filter(self.is_mine_id, joined_user_ids))
-
- rules_by_user = await self.store.bulk_get_push_rules(
- user_ids, on_invalidate=self.invalidate_all_cb
- )
-
- ret_rules_by_user.update(
- item for item in rules_by_user.items() if item[0] is not None
- )
-
- self.update_cache(sequence, members, ret_rules_by_user, state_group)
-
def update_cache(
self,
sequence: int,
diff --git a/synapse/storage/databases/main/cache.py b/synapse/storage/databases/main/cache.py
index 1653a6a9b6..a07d48f66c 100644
--- a/synapse/storage/databases/main/cache.py
+++ b/synapse/storage/databases/main/cache.py
@@ -217,6 +217,7 @@ class CacheInvalidationWorkerStore(SQLBaseStore):
if etype == EventTypes.Member:
self._membership_stream_cache.entity_has_changed(state_key, stream_ordering)
self.get_invited_rooms_for_local_user.invalidate((state_key,))
+ self.get_local_users_in_room.invalidate((room_id,))
if relates_to:
self.get_relations_for_event.invalidate((relates_to,))
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index 0df8ff5395..28190bf6f5 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1766,6 +1766,10 @@ class PersistEventsStore:
self.store.get_invited_rooms_for_local_user.invalidate,
(event.state_key,),
)
+ txn.call_after(
+ self.store.get_local_users_in_room.invalidate,
+ (event.room_id,),
+ )
# The `_get_membership_from_event_id` is immutable, except for the
# case where we look up an event *before* persisting it.
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index cc528fcf2d..70a30e75b0 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -444,6 +444,15 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return results
+ @cached()
+ async def get_local_users_in_room(self, room_id: str) -> List[str]:
+ return await self.db_pool.simple_select_onecol(
+ table="local_current_membership",
+ keyvalues={"room_id": room_id, "membership": Membership.JOIN},
+ retcol="user_id",
+ desc="get_local_users_in_room",
+ )
+
async def get_local_current_membership_for_user_in_room(
self, user_id: str, room_id: str
) -> Tuple[Optional[str], Optional[str]]:
|