summary refs log tree commit diff
path: root/synapse/push
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push')
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py37
1 files changed, 24 insertions, 13 deletions
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index d9b5478b53..82a72dc34f 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -15,8 +15,8 @@
 # limitations under the License.
 
 import logging
-from collections import namedtuple
 
+import attr
 from prometheus_client import Counter
 
 from synapse.api.constants import EventTypes, Membership, RelationTypes
@@ -26,7 +26,8 @@ from synapse.events.snapshot import EventContext
 from synapse.state import POWER_KEY
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import register_cache
-from synapse.util.caches.descriptors import cached
+from synapse.util.caches.descriptors import lru_cache
+from synapse.util.caches.lrucache import LruCache
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
@@ -120,7 +121,7 @@ class BulkPushRuleEvaluator:
             dict of user_id -> push_rules
         """
         room_id = event.room_id
-        rules_for_room = await self._get_rules_for_room(room_id)
+        rules_for_room = self._get_rules_for_room(room_id)
 
         rules_by_user = await rules_for_room.get_rules(event, context)
 
@@ -138,7 +139,7 @@ class BulkPushRuleEvaluator:
 
         return rules_by_user
 
-    @cached()
+    @lru_cache()
     def _get_rules_for_room(self, room_id):
         """Get the current RulesForRoom object for the given room id
 
@@ -275,12 +276,14 @@ class RulesForRoom:
     the entire cache for the room.
     """
 
-    def __init__(self, hs, room_id, rules_for_room_cache, room_push_rule_cache_metrics):
+    def __init__(
+        self, hs, room_id, rules_for_room_cache: LruCache, room_push_rule_cache_metrics
+    ):
         """
         Args:
             hs (HomeServer)
             room_id (str)
-            rules_for_room_cache(Cache): The cache object that caches these
+            rules_for_room_cache: The cache object that caches these
                 RoomsForUser objects.
             room_push_rule_cache_metrics (CacheMetric)
         """
@@ -489,13 +492,21 @@ class RulesForRoom:
             self.state_group = state_group
 
 
-class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
-    # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
-    # which namedtuple does for us (i.e. two _CacheContext are the same if
-    # their caches and keys match). This is important in particular to
-    # dedupe when we add callbacks to lru cache nodes, otherwise the number
-    # of callbacks would grow.
+@attr.attrs(slots=True, frozen=True)
+class _Invalidation:
+    # _Invalidation is passed as an `on_invalidate` callback to bulk_get_push_rules,
+    # which means that it it is stored on the bulk_get_push_rules cache entry. In order
+    # to ensure that we don't accumulate lots of redunant callbacks on the cache entry,
+    # we need to ensure that two _Invalidation objects are "equal" if they refer to the
+    # same `cache` and `room_id`.
+    #
+    # attrs provides suitable __hash__ and __eq__ methods, provided we remember to
+    # set `frozen=True`.
+
+    cache = attr.ib(type=LruCache)
+    room_id = attr.ib(type=str)
+
     def __call__(self):
-        rules = self.cache.get_immediate(self.room_id, None, update_metrics=False)
+        rules = self.cache.get(self.room_id, None, update_metrics=False)
         if rules:
             rules.invalidate_all()