summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/9845.misc1
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py161
-rw-r--r--synapse/storage/databases/main/roommember.py161
3 files changed, 182 insertions, 141 deletions
diff --git a/changelog.d/9845.misc b/changelog.d/9845.misc
new file mode 100644
index 0000000000..875dd6d131
--- /dev/null
+++ b/changelog.d/9845.misc
@@ -0,0 +1 @@
+Only store the raw data in the in-memory caches, rather than objects that include references to e.g. the data stores.
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 50b470c310..350646f458 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -106,6 +106,10 @@ class BulkPushRuleEvaluator:
         self.store = hs.get_datastore()
         self.auth = hs.get_auth()
 
+        # Used by `RulesForRoom` to ensure only one thing mutates the cache at a
+        # time. Keyed off room_id.
+        self._rules_linearizer = Linearizer(name="rules_for_room")
+
         self.room_push_rule_cache_metrics = register_cache(
             "cache",
             "room_push_rule_cache",
@@ -123,7 +127,16 @@ class BulkPushRuleEvaluator:
             dict of user_id -> push_rules
         """
         room_id = event.room_id
-        rules_for_room = self._get_rules_for_room(room_id)
+
+        rules_for_room_data = self._get_rules_for_room(room_id)
+        rules_for_room = RulesForRoom(
+            hs=self.hs,
+            room_id=room_id,
+            rules_for_room_cache=self._get_rules_for_room.cache,
+            room_push_rule_cache_metrics=self.room_push_rule_cache_metrics,
+            linearizer=self._rules_linearizer,
+            cached_data=rules_for_room_data,
+        )
 
         rules_by_user = await rules_for_room.get_rules(event, context)
 
@@ -142,17 +155,12 @@ class BulkPushRuleEvaluator:
         return rules_by_user
 
     @lru_cache()
-    def _get_rules_for_room(self, room_id: str) -> "RulesForRoom":
-        """Get the current RulesForRoom object for the given room id"""
-        # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
+    def _get_rules_for_room(self, room_id: str) -> "RulesForRoomData":
+        """Get the current RulesForRoomData object for the given room id"""
+        # It's important that the RulesForRoomData object gets added to self._get_rules_for_room.cache
         # before any lookup methods get called on it as otherwise there may be
         # a race if invalidate_all gets called (which assumes its in the cache)
-        return RulesForRoom(
-            self.hs,
-            room_id,
-            self._get_rules_for_room.cache,
-            self.room_push_rule_cache_metrics,
-        )
+        return RulesForRoomData()
 
     async def _get_power_levels_and_sender_level(
         self, event: EventBase, context: EventContext
@@ -282,11 +290,49 @@ def _condition_checker(
     return True
 
 
+@attr.s(slots=True)
+class RulesForRoomData:
+    """The data stored in the cache by `RulesForRoom`.
+
+    We don't store `RulesForRoom` directly in the cache as we want our caches to
+    *only* include data, and not references to e.g. the data stores.
+    """
+
+    # event_id -> (user_id, state)
+    member_map = attr.ib(type=Dict[str, Tuple[str, str]], factory=dict)
+    # user_id -> rules
+    rules_by_user = attr.ib(type=Dict[str, List[Dict[str, dict]]], factory=dict)
+
+    # The last state group we updated the caches for. If the state_group of
+    # a new event comes along, we know that we can just return the cached
+    # result.
+    # On invalidation of the rules themselves (if the user changes them),
+    # we invalidate everything and set state_group to `object()`
+    state_group = attr.ib(type=Union[object, int], factory=object)
+
+    # A sequence number to keep track of when we're allowed to update the
+    # cache. We bump the sequence number when we invalidate the cache. If
+    # the sequence number changes while we're calculating stuff we should
+    # not update the cache with it.
+    sequence = attr.ib(type=int, default=0)
+
+    # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
+    # owned by AS's, or remote users, etc. (I.e. users we will never need to
+    # calculate push for)
+    # These never need to be invalidated as we will never set up push for
+    # them.
+    uninteresting_user_set = attr.ib(type=Set[str], factory=set)
+
+
 class RulesForRoom:
     """Caches push rules for users in a room.
 
     This efficiently handles users joining/leaving the room by not invalidating
     the entire cache for the room.
+
+    A new instance is constructed for each call to
+    `BulkPushRuleEvaluator._get_rules_for_event`, with the cached data from
+    previous calls passed in.
     """
 
     def __init__(
@@ -295,6 +341,8 @@ class RulesForRoom:
         room_id: str,
         rules_for_room_cache: LruCache,
         room_push_rule_cache_metrics: CacheMetric,
+        linearizer: Linearizer,
+        cached_data: RulesForRoomData,
     ):
         """
         Args:
@@ -303,38 +351,21 @@ class RulesForRoom:
             rules_for_room_cache: The cache object that caches these
                 RoomsForUser objects.
             room_push_rule_cache_metrics: The metrics object
+            linearizer: The linearizer used to ensure only one thing mutates
+                the cache at a time. Keyed off room_id
+            cached_data: Cached data from previous calls to `self.get_rules`,
+                can be mutated.
         """
         self.room_id = room_id
         self.is_mine_id = hs.is_mine_id
         self.store = hs.get_datastore()
         self.room_push_rule_cache_metrics = room_push_rule_cache_metrics
 
-        self.linearizer = Linearizer(name="rules_for_room")
-
-        # event_id -> (user_id, state)
-        self.member_map = {}  # type: Dict[str, Tuple[str, str]]
-        # user_id -> rules
-        self.rules_by_user = {}  # type: Dict[str, List[Dict[str, dict]]]
-
-        # The last state group we updated the caches for. If the state_group of
-        # a new event comes along, we know that we can just return the cached
-        # result.
-        # On invalidation of the rules themselves (if the user changes them),
-        # we invalidate everything and set state_group to `object()`
-        self.state_group = object()
-
-        # A sequence number to keep track of when we're allowed to update the
-        # cache. We bump the sequence number when we invalidate the cache. If
-        # the sequence number changes while we're calculating stuff we should
-        # not update the cache with it.
-        self.sequence = 0
-
-        # A cache of user_ids that we *know* aren't interesting, e.g. user_ids
-        # owned by AS's, or remote users, etc. (I.e. users we will never need to
-        # calculate push for)
-        # These never need to be invalidated as we will never set up push for
-        # them.
-        self.uninteresting_user_set = set()  # type: Set[str]
+        # Used to ensure only one thing mutates the cache at a time. Keyed off
+        # room_id.
+        self.linearizer = linearizer
+
+        self.data = cached_data
 
         # We need to be clever on the invalidating caches callbacks, as
         # otherwise the invalidation callback holds a reference to the object,
@@ -352,25 +383,25 @@ class RulesForRoom:
         """
         state_group = context.state_group
 
-        if state_group and self.state_group == state_group:
+        if state_group and self.data.state_group == state_group:
             logger.debug("Using cached rules for %r", self.room_id)
             self.room_push_rule_cache_metrics.inc_hits()
-            return self.rules_by_user
+            return self.data.rules_by_user
 
-        with (await self.linearizer.queue(())):
-            if state_group and self.state_group == state_group:
+        with (await self.linearizer.queue(self.room_id)):
+            if state_group and self.data.state_group == state_group:
                 logger.debug("Using cached rules for %r", self.room_id)
                 self.room_push_rule_cache_metrics.inc_hits()
-                return self.rules_by_user
+                return self.data.rules_by_user
 
             self.room_push_rule_cache_metrics.inc_misses()
 
             ret_rules_by_user = {}
             missing_member_event_ids = {}
-            if state_group and self.state_group == context.prev_group:
+            if state_group and self.data.state_group == context.prev_group:
                 # If we have a simple delta then we can reuse most of the previous
                 # results.
-                ret_rules_by_user = self.rules_by_user
+                ret_rules_by_user = self.data.rules_by_user
                 current_state_ids = context.delta_ids
 
                 push_rules_delta_state_cache_metric.inc_hits()
@@ -393,24 +424,24 @@ class RulesForRoom:
                 if typ != EventTypes.Member:
                     continue
 
-                if user_id in self.uninteresting_user_set:
+                if user_id in self.data.uninteresting_user_set:
                     continue
 
                 if not self.is_mine_id(user_id):
-                    self.uninteresting_user_set.add(user_id)
+                    self.data.uninteresting_user_set.add(user_id)
                     continue
 
                 if self.store.get_if_app_services_interested_in_user(user_id):
-                    self.uninteresting_user_set.add(user_id)
+                    self.data.uninteresting_user_set.add(user_id)
                     continue
 
                 event_id = current_state_ids[key]
 
-                res = self.member_map.get(event_id, None)
+                res = self.data.member_map.get(event_id, None)
                 if res:
                     user_id, state = res
                     if state == Membership.JOIN:
-                        rules = self.rules_by_user.get(user_id, None)
+                        rules = self.data.rules_by_user.get(user_id, None)
                         if rules:
                             ret_rules_by_user[user_id] = rules
                     continue
@@ -430,7 +461,7 @@ class RulesForRoom:
             else:
                 # The push rules didn't change but lets update the cache anyway
                 self.update_cache(
-                    self.sequence,
+                    self.data.sequence,
                     members={},  # There were no membership changes
                     rules_by_user=ret_rules_by_user,
                     state_group=state_group,
@@ -461,7 +492,7 @@ class RulesForRoom:
                 for. Used when updating the cache.
             event: The event we are currently computing push rules for.
         """
-        sequence = self.sequence
+        sequence = self.data.sequence
 
         rows = await self.store.get_membership_from_event_ids(member_event_ids.values())
 
@@ -501,23 +532,11 @@ class RulesForRoom:
 
         self.update_cache(sequence, members, ret_rules_by_user, state_group)
 
-    def invalidate_all(self) -> None:
-        # Note: Don't hand this function directly to an invalidation callback
-        # as it keeps a reference to self and will stop this instance from being
-        # GC'd if it gets dropped from the rules_to_user cache. Instead use
-        # `self.invalidate_all_cb`
-        logger.debug("Invalidating RulesForRoom for %r", self.room_id)
-        self.sequence += 1
-        self.state_group = object()
-        self.member_map = {}
-        self.rules_by_user = {}
-        push_rules_invalidation_counter.inc()
-
     def update_cache(self, sequence, members, rules_by_user, state_group) -> None:
-        if sequence == self.sequence:
-            self.member_map.update(members)
-            self.rules_by_user = rules_by_user
-            self.state_group = state_group
+        if sequence == self.data.sequence:
+            self.data.member_map.update(members)
+            self.data.rules_by_user = rules_by_user
+            self.data.state_group = state_group
 
 
 @attr.attrs(slots=True, frozen=True)
@@ -535,6 +554,10 @@ class _Invalidation:
     room_id = attr.ib(type=str)
 
     def __call__(self) -> None:
-        rules = self.cache.get(self.room_id, None, update_metrics=False)
-        if rules:
-            rules.invalidate_all()
+        rules_data = self.cache.get(self.room_id, None, update_metrics=False)
+        if rules_data:
+            rules_data.sequence += 1
+            rules_data.state_group = object()
+            rules_data.member_map = {}
+            rules_data.rules_by_user = {}
+            push_rules_invalidation_counter.inc()
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index bd8513cd43..2a8532f8c1 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -23,8 +23,11 @@ from typing import (
     Optional,
     Set,
     Tuple,
+    Union,
 )
 
+import attr
+
 from synapse.api.constants import EventTypes, Membership
 from synapse.events import EventBase
 from synapse.events.snapshot import EventContext
@@ -43,7 +46,7 @@ from synapse.storage.roommember import (
     ProfileInfo,
     RoomsForUser,
 )
-from synapse.types import PersistedEventPosition, get_domain_from_id
+from synapse.types import PersistedEventPosition, StateMap, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
@@ -63,6 +66,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
     def __init__(self, database: DatabasePool, db_conn, hs):
         super().__init__(database, db_conn, hs)
 
+        # Used by `_get_joined_hosts` to ensure only one thing mutates the cache
+        # at a time. Keyed by room_id.
+        self._joined_host_linearizer = Linearizer("_JoinedHostsCache")
+
         # Is the current_state_events.membership up to date? Or is the
         # background update still running?
         self._current_state_events_membership_up_to_date = False
@@ -740,19 +747,82 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
     @cached(num_args=2, max_entries=10000, iterable=True)
     async def _get_joined_hosts(
-        self, room_id, state_group, current_state_ids, state_entry
-    ):
-        # We don't use `state_group`, its there so that we can cache based
-        # on it. However, its important that its never None, since two current_state's
-        # with a state_group of None are likely to be different.
+        self,
+        room_id: str,
+        state_group: int,
+        current_state_ids: StateMap[str],
+        state_entry: "_StateCacheEntry",
+    ) -> FrozenSet[str]:
+        # We don't use `state_group`, its there so that we can cache based on
+        # it. However, its important that its never None, since two
+        # current_state's with a state_group of None are likely to be different.
+        #
+        # The `state_group` must match the `state_entry.state_group` (if not None).
         assert state_group is not None
-
+        assert state_entry.state_group is None or state_entry.state_group == state_group
+
+        # We use a secondary cache of previous work to allow us to build up the
+        # joined hosts for the given state group based on previous state groups.
+        #
+        # We cache one object per room containing the results of the last state
+        # group we got joined hosts for. The idea is that generally
+        # `get_joined_hosts` is called with the "current" state group for the
+        # room, and so consecutive calls will be for consecutive state groups
+        # which point to the previous state group.
         cache = await self._get_joined_hosts_cache(room_id)
-        return await cache.get_destinations(state_entry)
+
+        # If the state group in the cache matches, we already have the data we need.
+        if state_entry.state_group == cache.state_group:
+            return frozenset(cache.hosts_to_joined_users)
+
+        # Since we'll mutate the cache we need to lock.
+        with (await self._joined_host_linearizer.queue(room_id)):
+            if state_entry.state_group == cache.state_group:
+                # Same state group, so nothing to do. We've already checked for
+                # this above, but the cache may have changed while waiting on
+                # the lock.
+                pass
+            elif state_entry.prev_group == cache.state_group:
+                # The cached work is for the previous state group, so we work out
+                # the delta.
+                for (typ, state_key), event_id in state_entry.delta_ids.items():
+                    if typ != EventTypes.Member:
+                        continue
+
+                    host = intern_string(get_domain_from_id(state_key))
+                    user_id = state_key
+                    known_joins = cache.hosts_to_joined_users.setdefault(host, set())
+
+                    event = await self.get_event(event_id)
+                    if event.membership == Membership.JOIN:
+                        known_joins.add(user_id)
+                    else:
+                        known_joins.discard(user_id)
+
+                        if not known_joins:
+                            cache.hosts_to_joined_users.pop(host, None)
+            else:
+                # The cache doesn't match the state group or prev state group,
+                # so we calculate the result from first principles.
+                joined_users = await self.get_joined_users_from_state(
+                    room_id, state_entry
+                )
+
+                cache.hosts_to_joined_users = {}
+                for user_id in joined_users:
+                    host = intern_string(get_domain_from_id(user_id))
+                    cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
+
+            if state_entry.state_group:
+                cache.state_group = state_entry.state_group
+            else:
+                cache.state_group = object()
+
+        return frozenset(cache.hosts_to_joined_users)
 
     @cached(max_entries=10000)
     def _get_joined_hosts_cache(self, room_id: str) -> "_JoinedHostsCache":
-        return _JoinedHostsCache(self, room_id)
+        return _JoinedHostsCache()
 
     @cached(num_args=2)
     async def did_forget(self, user_id: str, room_id: str) -> bool:
@@ -1062,71 +1132,18 @@ class RoomMemberStore(RoomMemberWorkerStore, RoomMemberBackgroundUpdateStore):
         await self.db_pool.runInteraction("forget_membership", f)
 
 
+@attr.s(slots=True)
 class _JoinedHostsCache:
-    """Cache for joined hosts in a room that is optimised to handle updates
-    via state deltas.
-    """
-
-    def __init__(self, store, room_id):
-        self.store = store
-        self.room_id = room_id
+    """The cached data used by the `_get_joined_hosts_cache`."""
 
-        self.hosts_to_joined_users = {}
+    # Dict of host to the set of their users in the room at the state group.
+    hosts_to_joined_users = attr.ib(type=Dict[str, Set[str]], factory=dict)
 
-        self.state_group = object()
-
-        self.linearizer = Linearizer("_JoinedHostsCache")
-
-        self._len = 0
-
-    async def get_destinations(self, state_entry: "_StateCacheEntry") -> Set[str]:
-        """Get set of destinations for a state entry
-
-        Args:
-            state_entry
-
-        Returns:
-            The destinations as a set.
-        """
-        if state_entry.state_group == self.state_group:
-            return frozenset(self.hosts_to_joined_users)
-
-        with (await self.linearizer.queue(())):
-            if state_entry.state_group == self.state_group:
-                pass
-            elif state_entry.prev_group == self.state_group:
-                for (typ, state_key), event_id in state_entry.delta_ids.items():
-                    if typ != EventTypes.Member:
-                        continue
-
-                    host = intern_string(get_domain_from_id(state_key))
-                    user_id = state_key
-                    known_joins = self.hosts_to_joined_users.setdefault(host, set())
-
-                    event = await self.store.get_event(event_id)
-                    if event.membership == Membership.JOIN:
-                        known_joins.add(user_id)
-                    else:
-                        known_joins.discard(user_id)
-
-                        if not known_joins:
-                            self.hosts_to_joined_users.pop(host, None)
-            else:
-                joined_users = await self.store.get_joined_users_from_state(
-                    self.room_id, state_entry
-                )
-
-                self.hosts_to_joined_users = {}
-                for user_id in joined_users:
-                    host = intern_string(get_domain_from_id(user_id))
-                    self.hosts_to_joined_users.setdefault(host, set()).add(user_id)
-
-            if state_entry.state_group:
-                self.state_group = state_entry.state_group
-            else:
-                self.state_group = object()
-            self._len = sum(len(v) for v in self.hosts_to_joined_users.values())
-        return frozenset(self.hosts_to_joined_users)
+    # The state group `hosts_to_joined_users` is derived from. Will be an object
+    # if the instance is newly created or if the state is not based on a state
+    # group. (An object is used as a sentinel value to ensure that it never is
+    # equal to anything else).
+    state_group = attr.ib(type=Union[object, int], factory=object)
 
     def __len__(self):
-        return self._len
+        return sum(len(v) for v in self.hosts_to_joined_users.values())