summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--UPGRADE.rst2
-rw-r--r--synapse/handlers/message.py5
-rw-r--r--synapse/push/action_generator.py12
-rw-r--r--synapse/push/bulk_push_rule_evaluator.py282
-rw-r--r--synapse/storage/appservice.py7
-rw-r--r--synapse/storage/push_rule.py2
-rw-r--r--synapse/storage/state.py40
-rw-r--r--synapse/util/caches/dictionary_cache.py57
-rw-r--r--tests/util/test_dict_cache.py2
9 files changed, 317 insertions, 92 deletions
diff --git a/UPGRADE.rst b/UPGRADE.rst
index 6164df8833..62b22e9108 100644
--- a/UPGRADE.rst
+++ b/UPGRADE.rst
@@ -33,7 +33,7 @@ To check whether your update was sucessfull, run:
 
 .. code:: bash
 
-	 # replace your.server.domain with ther domain of your synaspe homeserver
+	 # replace your.server.domain with ther domain of your synapse homeserver
 	 curl https://<your.server.domain>/_matrix/federation/v1/version 
 
 So for the Matrix.org HS server the URL would be: https://matrix.org/_matrix/federation/v1/version.
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index 196925edad..ba8776f288 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -54,6 +54,8 @@ class MessageHandler(BaseHandler):
         # This is to stop us from diverging history *too* much.
         self.limiter = Limiter(max_count=5)
 
+        self.action_generator = ActionGenerator(self.hs)
+
     @defer.inlineCallbacks
     def purge_history(self, room_id, event_id):
         event = yield self.store.get_event(event_id)
@@ -590,8 +592,7 @@ class MessageHandler(BaseHandler):
                 "Changing the room create event is forbidden",
             )
 
-        action_generator = ActionGenerator(self.hs)
-        yield action_generator.handle_push_actions_for_event(
+        yield self.action_generator.handle_push_actions_for_event(
             event, context
         )
 
diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py
index 3f75d3f921..0658497d9b 100644
--- a/synapse/push/action_generator.py
+++ b/synapse/push/action_generator.py
@@ -15,7 +15,7 @@
 
 from twisted.internet import defer
 
-from .bulk_push_rule_evaluator import evaluator_for_event
+from .bulk_push_rule_evaluator import BulkPushRuleEvaluator
 
 from synapse.util.metrics import Measure
 
@@ -29,6 +29,7 @@ class ActionGenerator:
         self.hs = hs
         self.clock = hs.get_clock()
         self.store = hs.get_datastore()
+        self.bulk_evaluator = BulkPushRuleEvaluator(hs)
         # really we want to get all user ids and all profile tags too,
         # since we want the actions for each profile tag for every user and
         # also actions for a client with no profile tag for each user.
@@ -38,16 +39,11 @@ class ActionGenerator:
 
     @defer.inlineCallbacks
     def handle_push_actions_for_event(self, event, context):
-        with Measure(self.clock, "evaluator_for_event"):
-            bulk_evaluator = yield evaluator_for_event(
-                event, self.hs, self.store, context
-            )
-
         with Measure(self.clock, "action_for_event_by_user"):
-            actions_by_user = yield bulk_evaluator.action_for_event_by_user(
+            actions_by_user = yield self.bulk_evaluator.action_for_event_by_user(
                 event, context
             )
 
         context.push_actions = [
-            (uid, actions) for uid, actions in actions_by_user.items()
+            (uid, actions) for uid, actions in actions_by_user.iteritems()
         ]
diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py
index f943ff640f..5b1f9a1c2d 100644
--- a/synapse/push/bulk_push_rule_evaluator.py
+++ b/synapse/push/bulk_push_rule_evaluator.py
@@ -19,60 +19,81 @@ from twisted.internet import defer
 
 from .push_rule_evaluator import PushRuleEvaluatorForEvent
 
-from synapse.api.constants import EventTypes
 from synapse.visibility import filter_events_for_clients_context
+from synapse.api.constants import EventTypes, Membership
+from synapse.util.caches.descriptors import cached
+from synapse.util.async import Linearizer
 
 
 logger = logging.getLogger(__name__)
 
 
-@defer.inlineCallbacks
-def evaluator_for_event(event, hs, store, context):
-    rules_by_user = yield store.bulk_get_push_rules_for_room(
-        event, context
-    )
-
-    # if this event is an invite event, we may need to run rules for the user
-    # who's been invited, otherwise they won't get told they've been invited
-    if event.type == 'm.room.member' and event.content['membership'] == 'invite':
-        invited_user = event.state_key
-        if invited_user and hs.is_mine_id(invited_user):
-            has_pusher = yield store.user_has_pusher(invited_user)
-            if has_pusher:
-                rules_by_user = dict(rules_by_user)
-                rules_by_user[invited_user] = yield store.get_push_rules_for_user(
-                    invited_user
-                )
-
-    defer.returnValue(BulkPushRuleEvaluator(
-        event.room_id, rules_by_user, store
-    ))
+rules_by_room = {}
 
 
 class BulkPushRuleEvaluator:
+    """Calculates the outcome of push rules for an event for all users in the
+    room at once.
     """
-    Runs push rules for all users in a room.
-    This is faster than running PushRuleEvaluator for each user because it
-    fetches all the rules for all the users in one (batched) db query
-    rather than doing multiple queries per-user. It currently uses
-    the same logic to run the actual rules, but could be optimised further
-    (see https://matrix.org/jira/browse/SYN-562)
-    """
-    def __init__(self, room_id, rules_by_user, store):
-        self.room_id = room_id
-        self.rules_by_user = rules_by_user
-        self.store = store
+
+    def __init__(self, hs):
+        self.hs = hs
+        self.store = hs.get_datastore()
+
+    @defer.inlineCallbacks
+    def _get_rules_for_event(self, event, context):
+        """This gets the rules for all users in the room at the time of the event,
+        as well as the push rules for the invitee if the event is an invite.
+
+        Returns:
+            dict of user_id -> push_rules
+        """
+        room_id = event.room_id
+        rules_for_room = self._get_rules_for_room(room_id)
+
+        rules_by_user = yield rules_for_room.get_rules(context)
+
+        # if this event is an invite event, we may need to run rules for the user
+        # who's been invited, otherwise they won't get told they've been invited
+        if event.type == 'm.room.member' and event.content['membership'] == 'invite':
+            invited = event.state_key
+            if invited and self.hs.is_mine_id(invited):
+                has_pusher = yield self.store.user_has_pusher(invited)
+                if has_pusher:
+                    rules_by_user = dict(rules_by_user)
+                    rules_by_user[invited] = yield self.store.get_push_rules_for_user(
+                        invited
+                    )
+
+        defer.returnValue(rules_by_user)
+
+    @cached(max_entries=10000)
+    def _get_rules_for_room(self, room_id):
+        """Get the current RulesForRoom object for the given room id
+
+        Returns:
+            RulesForRoom
+        """
+        # It's important that RulesForRoom gets added to self._get_rules_for_room.cache
+        # before any lookup methods get called on it as otherwise there may be
+        # a race if invalidate_all gets called (which assumes its in the cache)
+        return RulesForRoom(self.hs, room_id, self._get_rules_for_room.cache)
 
     @defer.inlineCallbacks
     def action_for_event_by_user(self, event, context):
+        """Given an event and context, evaluate the push rules and return
+        the results
+
+        Returns:
+            dict of user_id -> action
+        """
+        rules_by_user = yield self._get_rules_for_event(event, context)
         actions_by_user = {}
 
         # None of these users can be peeking since this list of users comes
         # from the set of users in the room, so we know for sure they're all
         # actually in the room.
-        user_tuples = [
-            (u, False) for u in self.rules_by_user.keys()
-        ]
+        user_tuples = [(u, False) for u in rules_by_user]
 
         filtered_by_user = yield filter_events_for_clients_context(
             self.store, user_tuples, [event], {event.event_id: context}
@@ -86,7 +107,7 @@ class BulkPushRuleEvaluator:
 
         condition_cache = {}
 
-        for uid, rules in self.rules_by_user.items():
+        for uid, rules in rules_by_user.iteritems():
             display_name = None
             profile_info = room_members.get(uid)
             if profile_info:
@@ -138,3 +159,190 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache):
             return False
 
     return True
+
+
+class RulesForRoom(object):
+    """Caches push rules for users in a room.
+
+    This efficiently handles users joining/leaving the room by not invalidating
+    the entire cache for the room.
+    """
+
+    def __init__(self, hs, room_id, rules_for_room_cache):
+        """
+        Args:
+            hs (HomeServer)
+            room_id (str)
+            rules_for_room_cache(Cache): The cache object that caches these
+                RoomsForUser objects.
+        """
+        self.room_id = room_id
+        self.is_mine_id = hs.is_mine_id
+        self.store = hs.get_datastore()
+
+        self.linearizer = Linearizer(name="rules_for_room")
+
+        self.member_map = {}  # event_id -> (user_id, state)
+        self.rules_by_user = {}  # user_id -> rules
+
+        # The last state group we updated the caches for. If the state_group of
+        # a new event comes along, we know that we can just return the cached
+        # result.
+        # On invalidation of the rules themselves (if the user changes them),
+        # we invalidate everything and set state_group to `object()`
+        self.state_group = object()
+
+        # A sequence number to keep track of when we're allowed to update the
+        # cache. We bump the sequence number when we invalidate the cache. If
+        # the sequence number changes while we're calculating stuff we should
+        # not update the cache with it.
+        self.sequence = 0
+
+        # We need to be clever on the invalidating caches callbacks, as
+        # otherwise the invalidation callback holds a reference to the object,
+        # potentially causing it to leak.
+        # To get around this we pass a function that on invalidations looks ups
+        # the RoomsForUser entry in the cache, rather than keeping a reference
+        # to self around in the callback.
+        def invalidate_all_cb():
+            rules = rules_for_room_cache.get(room_id, update_metrics=False)
+            if rules:
+                rules.invalidate_all()
+
+        self.invalidate_all_cb = invalidate_all_cb
+
+    @defer.inlineCallbacks
+    def get_rules(self, context):
+        """Given an event context return the rules for all users who are
+        currently in the room.
+        """
+        state_group = context.state_group
+
+        with (yield self.linearizer.queue(())):
+            if state_group and self.state_group == state_group:
+                defer.returnValue(self.rules_by_user)
+
+            ret_rules_by_user = {}
+            missing_member_event_ids = {}
+            if state_group and self.state_group == context.prev_group:
+                # If we have a simple delta then we can reuse most of the previous
+                # results.
+                ret_rules_by_user = self.rules_by_user
+                current_state_ids = context.delta_ids
+            else:
+                current_state_ids = context.current_state_ids
+
+            # Loop through to see which member events we've seen and have rules
+            # for and which we need to fetch
+            for key, event_id in current_state_ids.iteritems():
+                if key[0] != EventTypes.Member:
+                    continue
+
+                res = self.member_map.get(event_id, None)
+                if res:
+                    user_id, state = res
+                    if state == Membership.JOIN:
+                        rules = self.rules_by_user.get(user_id, None)
+                        if rules:
+                            ret_rules_by_user[user_id] = rules
+                    continue
+
+                user_id = key[1]
+                if not self.is_mine_id(user_id):
+                    continue
+
+                if self.store.get_if_app_services_interested_in_user(
+                    user_id, exclusive=True
+                ):
+                    continue
+
+                # If a user has left a room we remove their push rule. If they
+                # joined then we readd it later in _update_rules_with_member_event_ids
+                ret_rules_by_user.pop(user_id, None)
+                missing_member_event_ids[user_id] = event_id
+
+            if missing_member_event_ids:
+                # If we have some memebr events we haven't seen, look them up
+                # and fetch push rules for them if appropriate.
+                yield self._update_rules_with_member_event_ids(
+                    ret_rules_by_user, missing_member_event_ids, state_group
+                )
+
+        defer.returnValue(ret_rules_by_user)
+
+    @defer.inlineCallbacks
+    def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids,
+                                            state_group):
+        """Update the partially filled rules_by_user dict by fetching rules for
+        any newly joined users in the `member_event_ids` list.
+
+        Args:
+            ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets
+                updated with any new rules.
+            member_event_ids (list): List of event ids for membership events that
+                have happened since the last time we filled rules_by_user
+            state_group: The state group we are currently computing push rules
+                for. Used when updating the cache.
+        """
+        sequence = self.sequence
+
+        rows = yield self.store._simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=member_event_ids.values(),
+            retcols=('user_id', 'membership', 'event_id'),
+            keyvalues={},
+            batch_size=500,
+            desc="_get_rules_for_member_event_ids",
+        )
+
+        members = {
+            row["event_id"]: (row["user_id"], row["membership"])
+            for row in rows
+        }
+
+        interested_in_user_ids = set(user_id for user_id, _ in members.itervalues())
+
+        if_users_with_pushers = yield self.store.get_if_users_have_pushers(
+            interested_in_user_ids,
+            on_invalidate=self.invalidate_all_cb,
+        )
+
+        user_ids = set(
+            uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher
+        )
+
+        users_with_receipts = yield self.store.get_users_with_read_receipts_in_room(
+            self.room_id, on_invalidate=self.invalidate_all_cb,
+        )
+
+        # any users with pushers must be ours: they have pushers
+        for uid in users_with_receipts:
+            if uid in interested_in_user_ids:
+                user_ids.add(uid)
+
+        rules_by_user = yield self.store.bulk_get_push_rules(
+            user_ids, on_invalidate=self.invalidate_all_cb,
+        )
+
+        ret_rules_by_user.update(
+            item for item in rules_by_user.iteritems() if item[0] is not None
+        )
+
+        self.update_cache(sequence, members, ret_rules_by_user, state_group)
+
+    def invalidate_all(self):
+        # Note: Don't hand this function directly to an invalidation callback
+        # as it keeps a reference to self and will stop this instance from being
+        # GC'd if it gets dropped from the rules_to_user cache. Instead use
+        # `self.invalidate_all_cb`
+        self.sequence += 1
+        self.state_group = object()
+        self.member_map = {}
+        self.rules_by_user = {}
+
+    def update_cache(self, sequence, members, rules_by_user, state_group):
+        if sequence == self.sequence:
+            self.member_map.update(members)
+            self.rules_by_user = rules_by_user
+            self.state_group = state_group
diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py
index 514570561f..0e9e8d3452 100644
--- a/synapse/storage/appservice.py
+++ b/synapse/storage/appservice.py
@@ -39,12 +39,15 @@ class ApplicationServiceStore(SQLBaseStore):
     def get_app_services(self):
         return self.services_cache
 
-    def get_if_app_services_interested_in_user(self, user_id):
+    def get_if_app_services_interested_in_user(self, user_id, exclusive=False):
         """Check if the user is one associated with an app service
         """
         for service in self.services_cache:
             if service.is_interested_in_user(user_id):
-                return True
+                if exclusive:
+                    return service.is_exclusive_user(user_id)
+                else:
+                    return True
         return False
 
     def get_app_service_by_user_id(self, user_id):
diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py
index 0a819d32c5..65bad3fad6 100644
--- a/synapse/storage/push_rule.py
+++ b/synapse/storage/push_rule.py
@@ -163,7 +163,7 @@ class PushRuleStore(SQLBaseStore):
         local_users_in_room = set(
             u for u in users_in_room
             if self.hs.is_mine_id(u)
-            and not self.get_if_app_services_interested_in_user(u)
+            and not self.get_if_app_services_interested_in_user(u, exclusive=True)
         )
 
         # users in the room who have pushers need to get push rules run because
diff --git a/synapse/storage/state.py b/synapse/storage/state.py
index 85acf2ad1e..a7c3d401d4 100644
--- a/synapse/storage/state.py
+++ b/synapse/storage/state.py
@@ -563,20 +563,22 @@ class StateStore(SQLBaseStore):
                 where a `state_key` of `None` matches all state_keys for the
                 `type`.
         """
-        is_all, state_dict_ids = self._state_group_cache.get(group)
+        is_all, known_absent, state_dict_ids = self._state_group_cache.get(group)
 
         type_to_key = {}
         missing_types = set()
+
         for typ, state_key in types:
+            key = (typ, state_key)
             if state_key is None:
                 type_to_key[typ] = None
-                missing_types.add((typ, state_key))
+                missing_types.add(key)
             else:
                 if type_to_key.get(typ, object()) is not None:
                     type_to_key.setdefault(typ, set()).add(state_key)
 
-                if (typ, state_key) not in state_dict_ids:
-                    missing_types.add((typ, state_key))
+                if key not in state_dict_ids and key not in known_absent:
+                    missing_types.add(key)
 
         sentinel = object()
 
@@ -590,7 +592,7 @@ class StateStore(SQLBaseStore):
                 return True
             return False
 
-        got_all = not (missing_types or types is None)
+        got_all = is_all or not missing_types
 
         return {
             k: v for k, v in state_dict_ids.iteritems()
@@ -607,7 +609,7 @@ class StateStore(SQLBaseStore):
         Args:
             group: The state group to lookup
         """
-        is_all, state_dict_ids = self._state_group_cache.get(group)
+        is_all, _, state_dict_ids = self._state_group_cache.get(group)
 
         return state_dict_ids, is_all
 
@@ -624,7 +626,7 @@ class StateStore(SQLBaseStore):
         missing_groups = []
         if types is not None:
             for group in set(groups):
-                state_dict_ids, missing_types, got_all = self._get_some_state_from_cache(
+                state_dict_ids, _, got_all = self._get_some_state_from_cache(
                     group, types
                 )
                 results[group] = state_dict_ids
@@ -653,19 +655,7 @@ class StateStore(SQLBaseStore):
             # Now we want to update the cache with all the things we fetched
             # from the database.
             for group, group_state_dict in group_to_state_dict.iteritems():
-                if types:
-                    # We delibrately put key -> None mappings into the cache to
-                    # cache absence of the key, on the assumption that if we've
-                    # explicitly asked for some types then we will probably ask
-                    # for them again.
-                    state_dict = {
-                        (intern_string(etype), intern_string(state_key)): None
-                        for (etype, state_key) in types
-                    }
-                    state_dict.update(results[group])
-                    results[group] = state_dict
-                else:
-                    state_dict = results[group]
+                state_dict = results[group]
 
                 state_dict.update(
                     ((intern_string(k[0]), intern_string(k[1])), to_ascii(v))
@@ -677,17 +667,9 @@ class StateStore(SQLBaseStore):
                     key=group,
                     value=state_dict,
                     full=(types is None),
+                    known_absent=types,
                 )
 
-        # Remove all the entries with None values. The None values were just
-        # used for bookkeeping in the cache.
-        for group, state_dict in results.iteritems():
-            results[group] = {
-                key: event_id
-                for key, event_id in state_dict.iteritems()
-                if event_id
-            }
-
         defer.returnValue(results)
 
     def get_next_state_group(self):
diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py
index cb6933c61c..d4105822b3 100644
--- a/synapse/util/caches/dictionary_cache.py
+++ b/synapse/util/caches/dictionary_cache.py
@@ -23,7 +23,17 @@ import logging
 logger = logging.getLogger(__name__)
 
 
-class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))):
+class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))):
+    """Returned when getting an entry from the cache
+
+    Attributes:
+        full (bool): Whether the cache has the full or dict or just some keys.
+            If not full then not all requested keys will necessarily be present
+            in `value`
+        known_absent (set): Keys that were looked up in the dict and were not
+            there.
+        value (dict): The full or partial dict value
+    """
     def __len__(self):
         return len(self.value)
 
@@ -58,21 +68,31 @@ class DictionaryCache(object):
                 )
 
     def get(self, key, dict_keys=None):
+        """Fetch an entry out of the cache
+
+        Args:
+            key
+            dict_key(list): If given a set of keys then return only those keys
+                that exist in the cache.
+
+        Returns:
+            DictionaryEntry
+        """
         entry = self.cache.get(key, self.sentinel)
         if entry is not self.sentinel:
             self.metrics.inc_hits()
 
             if dict_keys is None:
-                return DictionaryEntry(entry.full, dict(entry.value))
+                return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value))
             else:
-                return DictionaryEntry(entry.full, {
+                return DictionaryEntry(entry.full, entry.known_absent, {
                     k: entry.value[k]
                     for k in dict_keys
                     if k in entry.value
                 })
 
         self.metrics.inc_misses()
-        return DictionaryEntry(False, {})
+        return DictionaryEntry(False, set(), {})
 
     def invalidate(self, key):
         self.check_thread()
@@ -87,19 +107,34 @@ class DictionaryCache(object):
         self.sequence += 1
         self.cache.clear()
 
-    def update(self, sequence, key, value, full=False):
+    def update(self, sequence, key, value, full=False, known_absent=None):
+        """Updates the entry in the cache
+
+        Args:
+            sequence
+            key
+            value (dict): The value to update the cache with.
+            full (bool): Whether the given value is the full dict, or just a
+                partial subset there of. If not full then any existing entries
+                for the key will be updated.
+            known_absent (set): Set of keys that we know don't exist in the full
+                dict.
+        """
         self.check_thread()
         if self.sequence == sequence:
             # Only update the cache if the caches sequence number matches the
             # number that the cache had before the SELECT was started (SYN-369)
+            if known_absent is None:
+                known_absent = set()
             if full:
-                self._insert(key, value)
+                self._insert(key, value, known_absent)
             else:
-                self._update_or_insert(key, value)
+                self._update_or_insert(key, value, known_absent)
 
-    def _update_or_insert(self, key, value):
-        entry = self.cache.setdefault(key, DictionaryEntry(False, {}))
+    def _update_or_insert(self, key, value, known_absent):
+        entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {}))
         entry.value.update(value)
+        entry.known_absent.update(known_absent)
 
-    def _insert(self, key, value):
-        self.cache[key] = DictionaryEntry(True, value)
+    def _insert(self, key, value, known_absent):
+        self.cache[key] = DictionaryEntry(True, known_absent, value)
diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py
index 272b71034a..bc92f85fa6 100644
--- a/tests/util/test_dict_cache.py
+++ b/tests/util/test_dict_cache.py
@@ -28,7 +28,7 @@ class DictCacheTestCase(unittest.TestCase):
         key = "test_simple_cache_hit_full"
 
         v = self.cache.get(key)
-        self.assertEqual((False, {}), v)
+        self.assertEqual((False, set(), {}), v)
 
         seq = self.cache.sequence
         test_value = {"test": "test_simple_cache_hit_full"}