summary refs log tree commit diff
path: root/synapse/push
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/push')
-rw-r--r--synapse/push/action_generator.py2
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py23
2 files changed, 17 insertions, 8 deletions
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 0658497d9b..fe09d50d55 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -24,7 +24,7 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class ActionGenerator:
+class ActionGenerator(object):
     def __init__(self, hs):
         self.hs = hs
         self.clock = hs.get_clock()
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index 0158026915..760d567ca1 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -24,6 +24,8 @@ from synapse.api.constants import EventTypes, Membership
 from synapse.util.caches.descriptors import cached
 from synapse.util.async import Linearizer
 
+from collections import namedtuple
+
 
 logger = logging.getLogger(__name__)
 
@@ -31,7 +33,7 @@ logger = logging.getLogger(__name__)
 rules_by_room = {}
 
 
-class BulkPushRuleEvaluator:
+class BulkPushRuleEvaluator(object):
     """Calculates the outcome of push rules for an event for all users in the
     room at once.
     """
@@ -204,12 +206,7 @@ class RulesForRoom(object):
         # To get around this we pass a function that on invalidations looks ups
         # the RoomsForUser entry in the cache, rather than keeping a reference
         # to self around in the callback.
-        def invalidate_all_cb():
-            rules = rules_for_room_cache.get(room_id, update_metrics=False)
-            if rules:
-                rules.invalidate_all()
-
-        self.invalidate_all_cb = invalidate_all_cb
+        self.invalidate_all_cb = _Invalidation(rules_for_room_cache, room_id)
 
     @defer.inlineCallbacks
     def get_rules(self, context):
@@ -347,3 +344,15 @@ class RulesForRoom(object):
             self.member_map.update(members)
             self.rules_by_user = rules_by_user
             self.state_group = state_group
+
+
+class _Invalidation(namedtuple("_Invalidation", ("cache", "room_id"))):
+    # We rely on _CacheContext implementing __eq__ and __hash__ sensibly,
+    # which namedtuple does for us (i.e. two _CacheContext are the same if
+    # their caches and keys match). This is important in particular to
+    # dedupe when we add callbacks to lru cache nodes, otherwise the number
+    # of callbacks would grow.
+    def __call__(self):
+        rules = self.cache.get(self.room_id, None, update_metrics=False)
+        if rules:
+            rules.invalidate_all()