diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 50b470c310..350646f458 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -106,6 +106,10 @@ class BulkPushRuleEvaluator:
self.store = hs.get_datastore()
self.auth = hs.get_auth()
+ # Used by `RulesForRoom` to ensure only one thing mutates the cache at a
+ # time. Keyed off room_id.
+ self._rules_linearizer = Linearizer(name="rules_for_room")
+
self.room_push_rule_cache_metrics = register_cache(
"cache",
"room_push_rule_cache",
@@ -123,7 +127,16 @@ class BulkPushRuleEvaluator:
dict of user_id -> push_rules
"""
room_id = event.room_id
- rules_for_room = self._get_rules_for_room(room_id)
+
+ rules_for_room_data = self._get_rules_for_room(room_id)
+ rules_for_room = RulesForRoom(
+ hs=self.hs,
+ room_id=room_id,
+ rules_for_room_cache=self._get_rules_for_room.cache,
+ room_push_rule_cache_metrics=self.room_push_rule_cache_metrics,
+ linearizer=self._rules_linearizer,
+ cached_data=rules_for_room_data,
+ )
rules_by_user = await rules_for_room.get_rules(event, context)
@@ -142,17 +155,12 @@ class BulkPushRuleEvaluator:
return rules_by_user
@lru_cache()
- def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
- """Get the current RulesForRoom object for the given room id"""
- # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
+ def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData":
+ """Get the current RulesForRoomData object for the given room id"""
+ # It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache
# before any lookup methods get called on it as otherwise there may be
# a race if invalidate_all gets called (which assumes its in the cache)
- return RulesForRoom(
- self.hs,
- room_id,
- self._get_rules_for_room.cache,
- self.room_push_rule_cache_metrics,
- )
+ return RulesForRoomData()
async def _get_power_levels_and_sender_level(
self, event: EventBase, context: EventContext
@@ -282,11 +290,49 @@ def _condition_checker(
return True
+@attr.s(slots=True)
+class RulesForRoomData:
+ """The data stored in the cache by `RulesForRoom`.
+
+ We don't store `RulesForRoom` directly in the cache as we want our caches to
+ *only* include data, and not references to e.g. the data stores.
+ """
+
+ # event_id -> (user_id, state)
+ member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict)
+ # user_id -> rules
+ rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict)
+
+ # The last state group we updated the caches for. If the state_group of
+ # a new event comes along, we know that we can just return the cached
+ # result.
+ # On invalidation of the rules themselves (if the user changes them),
+ # we invalidate everything and set state_group to `object()`
+ state_group = attr.ib(type=Union[object, int], factory=object)
+
+ # A sequence number to keep track of when we're allowed to update the
+ # cache. We bump the sequence number when we invalidate the cache. If
+ # the sequence number changes while we're calculating stuff we should
+ # not update the cache with it.
+ sequence = attr.ib(type=int, default=0)
+
+ # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
+ # owned by AS's, or remote users, etc. (I.e. users we will never need to
+ # calculate push for)
+ # These never need to be invalidated as we will never set up push for
+ # them.
+ uninteresting_user_set = attr.ib(type=Set[str], factory=set)
+
+
class RulesForRoom:
"""Caches push rules for users in a room.
This efficiently handles users joining/leaving the room by not invalidating
the entire cache for the room.
+
+ A new instance is constructed for each call to
+ `BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from
+ previous calls passed in.
"""
def __init__(
@@ -295,6 +341,8 @@ class RulesForRoom:
room_id: str,
rules_for_room_cache: LruCache,
room_push_rule_cache_metrics: CacheMetric,
+ linearizer: Linearizer,
+ cached_data: RulesForRoomData,
):
"""
Args:
@@ -303,38 +351,21 @@ class RulesForRoom:
rules_for_room_cache: The cache object that caches these
RoomsForUser objects.
room_push_rule_cache_metrics: The metrics object
+ linearizer: The linearizer used to ensure only one thing mutates
+ the cache at a time. Keyed off room_id
+ cached_data: Cached data from previous calls to `self.get_rules`,
+ can be mutated.
"""
self.room_id = room_id
self.is_mine_id = hs.is_mine_id
self.store = hs.get_datastore()
self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
- self.linearizer = Linearizer(name="rules_for_room")
-
- # event_id -> (user_id, state)
- self.member_map = {} # type: Dict[str, Tuple[str, str]]
- # user_id -> rules
- self.rules_by_user = {} # type: Dict[str, List[Dict[str, dict]]]
-
- # The last state group we updated the caches for. If the state_group of
- # a new event comes along, we know that we can just return the cached
- # result.
- # On invalidation of the rules themselves (if the user changes them),
- # we invalidate everything and set state_group to `object()`
- self.state_group = object()
-
- # A sequence number to keep track of when we're allowed to update the
- # cache. We bump the sequence number when we invalidate the cache. If
- # the sequence number changes while we're calculating stuff we should
- # not update the cache with it.
- self.sequence = 0
-
- # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
- # owned by AS's, or remote users, etc. (I.e. users we will never need to
- # calculate push for)
- # These never need to be invalidated as we will never set up push for
- # them.
- self.uninteresting_user_set = set() # type: Set[str]
+ # Used to ensure only one thing mutates the cache at a time. Keyed off
+ # room_id.
+ self.linearizer = linearizer
+
+ self.data = cached_data
# We need to be clever on the invalidating caches callbacks, as
# otherwise the invalidation callback holds a reference to the object,
@@ -352,25 +383,25 @@ class RulesForRoom:
"""
state_group = context.state_group
- if state_group and self.state_group == state_group:
+ if state_group and self.data.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
- return self.rules_by_user
+ return self.data.rules_by_user
- with (await self.linearizer.queue(())):
- if state_group and self.state_group == state_group:
+ with (await self.linearizer.queue(self.room_id)):
+ if state_group and self.data.state_group == state_group:
logger.debug("Using cached rules for %r", self.room_id)
self.room_push_rule_cache_metrics.inc_hits()
- return self.rules_by_user
+ return self.data.rules_by_user
self.room_push_rule_cache_metrics.inc_misses()
ret_rules_by_user = {}
missing_member_event_ids = {}
- if state_group and self.state_group == context.prev_group:
+ if state_group and self.data.state_group == context.prev_group:
# If we have a simple delta then we can reuse most of the previous
# results.
- ret_rules_by_user = self.rules_by_user
+ ret_rules_by_user = self.data.rules_by_user
current_state_ids = context.delta_ids
push_rules_delta_state_cache_metric.inc_hits()
@@ -393,24 +424,24 @@ class RulesForRoom:
if typ != EventTypes.Member:
continue
- if user_id in self.uninteresting_user_set:
+ if user_id in self.data.uninteresting_user_set:
continue
if not self.is_mine_id(user_id):
- self.uninteresting_user_set.add(user_id)
+ self.data.uninteresting_user_set.add(user_id)
continue
if self.store.get_if_app_services_interested_in_user(user_id):
- self.uninteresting_user_set.add(user_id)
+ self.data.uninteresting_user_set.add(user_id)
continue
event_id = current_state_ids[key]
- res = self.member_map.get(event_id, None)
+ res = self.data.member_map.get(event_id, None)
if res:
user_id, state = res
if state == Membership.JOIN:
- rules = self.rules_by_user.get(user_id, None)
+ rules = self.data.rules_by_user.get(user_id, None)
if rules:
ret_rules_by_user[user_id] = rules
continue
@@ -430,7 +461,7 @@ class RulesForRoom:
else:
# The push rules didn't change but lets update the cache anyway
self.update_cache(
- self.sequence,
+ self.data.sequence,
members={}, # There were no membership changes
rules_by_user=ret_rules_by_user,
state_group=state_group,
@@ -461,7 +492,7 @@ class RulesForRoom:
for. Used when updating the cache.
event: The event we are currently computing push rules for.
"""
- sequence = self.sequence
+ sequence = self.data.sequence
rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
@@ -501,23 +532,11 @@ class RulesForRoom:
self.update_cache(sequence, members, ret_rules_by_user, state_group)
- def invalidate_all(self) -> None:
- # Note: Don't hand this function directly to an invalidation callback
- # as it keeps a reference to self and will stop this instance from being
- # GC'd if it gets dropped from the rules_to_user cache. Instead use
- # `self.invalidate_all_cb`
- logger.debug("Invalidating RulesForRoom for %r", self.room_id)
- self.sequence += 1
- self.state_group = object()
- self.member_map = {}
- self.rules_by_user = {}
- push_rules_invalidation_counter.inc()
-
def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
- if sequence == self.sequence:
- self.member_map.update(members)
- self.rules_by_user = rules_by_user
- self.state_group = state_group
+ if sequence == self.data.sequence:
+ self.data.member_map.update(members)
+ self.data.rules_by_user = rules_by_user
+ self.data.state_group = state_group
@attr.attrs(slots=True, frozen=True)
@@ -535,6 +554,10 @@ class _Invalidation:
room_id = attr.ib(type=str)
def __call__(self) -> None:
- rules = self.cache.get(self.room_id, None, update_metrics=False)
- if rules:
- rules.invalidate_all()
+ rules_data = self.cache.get(self.room_id, None, update_metrics=False)
+ if rules_data:
+ rules_data.sequence += 1
+ rules_data.state_group = object()
+ rules_data.member_map = {}
+ rules_data.rules_by_user = {}
+ push_rules_invalidation_counter.inc()
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bd8513cd43..2a8532f8c1 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -23,8 +23,11 @@ from typing import (
Optional,
Set,
Tuple,
+ Union,
)
+import attr
+
from synapse.api.constants import EventTypes, Membership
from synapse.events import EventBase
from synapse.events.snapshot import EventContext
@@ -43,7 +46,7 @@ from synapse.storage.roommember import (
ProfileInfo,
RoomsForUser,
)
-from synapse.types import PersistedEventPosition, get_domain_from_id
+from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id
from synapse.util.async_helpers import Linearizer
from synapse.util.caches import intern_string
from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -63,6 +66,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
def __init__(self, database: DatabasePool, db_conn, hs):
super().__init__(database, db_conn, hs)
+ # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
+ # at a time. Keyed by room_id.
+ self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
+
# Is the current_state_events.membership up to date? Or is the
# background update still running?
self._current_state_events_membership_up_to_date = False
@@ -740,19 +747,82 @@ class RoomMemberWorkerStore(EventsWorkerStore):
@cached(num_args=2, max_entries=10000, iterable=True)
async def _get_joined_hosts(
- self, room_id, state_group, current_state_ids, state_entry
- ):
- # We don't use `state_group`, its there so that we can cache based
- # on it. However, its important that its never None, since two current_state's
- # with a state_group of None are likely to be different.
+ self,
+ room_id: str,
+ state_group: int,
+ current_state_ids: StateMap[str],
+ state_entry: "_StateCacheEntry",
+ ) -> FrozenSet[str]:
+ # We don't use `state_group`, its there so that we can cache based on
+ # it. However, its important that its never None, since two
+ # current_state's with a state_group of None are likely to be different.
+ #
+ # The `state_group` must match the `state_entry.state_group` (if not None).
assert state_group is not None
-
+ assert state_entry.state_group is None or state_entry.state_group == state_group
+
+ # We use a secondary cache of previous work to allow us to build up the
+ # joined hosts for the given state group based on previous state groups.
+ #
+ # We cache one object per room containing the results of the last state
+ # group we got joined hosts for. The idea is that generally
+ # `get_joined_hosts` is called with the "current" state group for the
+ # room, and so consecutive calls will be for consecutive state groups
+ # which point to the previous state group.
cache = await self._get_joined_hosts_cache(room_id)
- return await cache.get_destinations(state_entry)
+
+ # If the state group in the cache matches, we already have the data we need.
+ if state_entry.state_group == cache.state_group:
+ return frozenset(cache.hosts_to_joined_users)
+
+ # Since we'll mutate the cache we need to lock.
+ with (await self._joined_host_linearizer.queue(room_id)):
+ if state_entry.state_group == cache.state_group:
+ # Same state group, so nothing to do. We've already checked for
+ # this above, but the cache may have changed while waiting on
+ # the lock.
+ pass
+ elif state_entry.prev_group == cache.state_group:
+ # The cached work is for the previous state group, so we work out
+ # the delta.
+ for (typ, state_key), event_id in state_entry.delta_ids.items():
+ if typ != EventTypes.Member:
+ continue
+
+ host = intern_string(get_domain_from_id(state_key))
+ user_id = state_key
+ known_joins = cache.hosts_to_joined_users.setdefault(host, set())
+
+ event = await self.get_event(event_id)
+ if event.membership == Membership.JOIN:
+ known_joins.add(user_id)
+ else:
+ known_joins.discard(user_id)
+
+ if not known_joins:
+ cache.hosts_to_joined_users.pop(host, None)
+ else:
+ # The cache doesn't match the state group or prev state group,
+ # so we calculate the result from first principles.
+ joined_users = await self.get_joined_users_from_state(
+ room_id, state_entry
+ )
+
+ cache.hosts_to_joined_users = {}
+ for user_id in joined_users:
+ host = intern_string(get_domain_from_id(user_id))
+ cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+ if state_entry.state_group:
+ cache.state_group = state_entry.state_group
+ else:
+ cache.state_group = object()
+
+ return frozenset(cache.hosts_to_joined_users)
@cached(max_entries=10000)
def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
- return _JoinedHostsCache(self, room_id)
+ return _JoinedHostsCache()
@cached(num_args=2)
async def did_forget(self, user_id: str, room_id: str) -> bool:
@@ -1062,71 +1132,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
await self.db_pool.runInteraction("forget_membership", f)
+@attr.s(slots=True)
class _JoinedHostsCache:
- """Cache for joined hosts in a room that is optimised to handle updates
- via state deltas.
- """
-
- def __init__(self, store, room_id):
- self.store = store
- self.room_id = room_id
+ """The cached data used by the `_get_joined_hosts_cache`."""
- self.hosts_to_joined_users = {}
+ # Dict of host to the set of their users in the room at the state group.
+ hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict)
- self.state_group = object()
-
- self.linearizer = Linearizer("_JoinedHostsCache")
-
- self._len = 0
-
- async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
- """Get set of destinations for a state entry
-
- Args:
- state_entry
-
- Returns:
- The destinations as a set.
- """
- if state_entry.state_group == self.state_group:
- return frozenset(self.hosts_to_joined_users)
-
- with (await self.linearizer.queue(())):
- if state_entry.state_group == self.state_group:
- pass
- elif state_entry.prev_group == self.state_group:
- for (typ, state_key), event_id in state_entry.delta_ids.items():
- if typ != EventTypes.Member:
- continue
-
- host = intern_string(get_domain_from_id(state_key))
- user_id = state_key
- known_joins = self.hosts_to_joined_users.setdefault(host, set())
-
- event = await self.store.get_event(event_id)
- if event.membership == Membership.JOIN:
- known_joins.add(user_id)
- else:
- known_joins.discard(user_id)
-
- if not known_joins:
- self.hosts_to_joined_users.pop(host, None)
- else:
- joined_users = await self.store.get_joined_users_from_state(
- self.room_id, state_entry
- )
-
- self.hosts_to_joined_users = {}
- for user_id in joined_users:
- host = intern_string(get_domain_from_id(user_id))
- self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
- if state_entry.state_group:
- self.state_group = state_entry.state_group
- else:
- self.state_group = object()
- self._len = sum(len(v) for v in self.hosts_to_joined_users.values())
- return frozenset(self.hosts_to_joined_users)
+ # The state group `hosts_to_joined_users` is derived from. Will be an object
+ # if the instance is newly created or if the state is not based on a state
+ # group. (An object is used as a sentinel value to ensure that it never is
+ # equal to anything else).
+ state_group = attr.ib(type=Union[object, int], factory=object)
def __len__(self):
- return self._len
+ return sum(len(v) for v in self.hosts_to_joined_users.values())
|