diff options
-rw-r--r-- | UPGRADE.rst | 2 | ||||
-rw-r--r-- | synapse/handlers/message.py | 5 | ||||
-rw-r--r-- | synapse/push/action_generator.py | 12 | ||||
-rw-r--r-- | synapse/push/bulk_push_rule_evaluator.py | 282 | ||||
-rw-r--r-- | synapse/storage/appservice.py | 7 | ||||
-rw-r--r-- | synapse/storage/push_rule.py | 2 | ||||
-rw-r--r-- | synapse/storage/state.py | 40 | ||||
-rw-r--r-- | synapse/util/caches/dictionary_cache.py | 57 | ||||
-rw-r--r-- | tests/util/test_dict_cache.py | 2 |
9 files changed, 317 insertions, 92 deletions
diff --git a/UPGRADE.rst b/UPGRADE.rst index 6164df8833..62b22e9108 100644 --- a/UPGRADE.rst +++ b/UPGRADE.rst @@ -33,7 +33,7 @@ To check whether your update was sucessfull, run: .. code:: bash - # replace your.server.domain with ther domain of your synaspe homeserver + # replace your.server.domain with ther domain of your synapse homeserver curl https://<your.server.domain>/_matrix/federation/v1/version So for the Matrix.org HS server the URL would be: https://matrix.org/_matrix/federation/v1/version. diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py index 196925edad..ba8776f288 100644 --- a/synapse/handlers/message.py +++ b/synapse/handlers/message.py @@ -54,6 +54,8 @@ class MessageHandler(BaseHandler): # This is to stop us from diverging history *too* much. self.limiter = Limiter(max_count=5) + self.action_generator = ActionGenerator(self.hs) + @defer.inlineCallbacks def purge_history(self, room_id, event_id): event = yield self.store.get_event(event_id) @@ -590,8 +592,7 @@ class MessageHandler(BaseHandler): "Changing the room create event is forbidden", ) - action_generator = ActionGenerator(self.hs) - yield action_generator.handle_push_actions_for_event( + yield self.action_generator.handle_push_actions_for_event( event, context ) diff --git a/synapse/push/action_generator.py b/synapse/push/action_generator.py index 3f75d3f921..0658497d9b 100644 --- a/synapse/push/action_generator.py +++ b/synapse/push/action_generator.py @@ -15,7 +15,7 @@ from twisted.internet import defer -from .bulk_push_rule_evaluator import evaluator_for_event +from .bulk_push_rule_evaluator import BulkPushRuleEvaluator from synapse.util.metrics import Measure @@ -29,6 +29,7 @@ class ActionGenerator: self.hs = hs self.clock = hs.get_clock() self.store = hs.get_datastore() + self.bulk_evaluator = BulkPushRuleEvaluator(hs) # really we want to get all user ids and all profile tags too, # since we want the actions for each profile tag for every user and # also actions for a client with no profile tag for each user. @@ -38,16 +39,11 @@ class ActionGenerator: @defer.inlineCallbacks def handle_push_actions_for_event(self, event, context): - with Measure(self.clock, "evaluator_for_event"): - bulk_evaluator = yield evaluator_for_event( - event, self.hs, self.store, context - ) - with Measure(self.clock, "action_for_event_by_user"): - actions_by_user = yield bulk_evaluator.action_for_event_by_user( + actions_by_user = yield self.bulk_evaluator.action_for_event_by_user( event, context ) context.push_actions = [ - (uid, actions) for uid, actions in actions_by_user.items() + (uid, actions) for uid, actions in actions_by_user.iteritems() ] diff --git a/synapse/push/bulk_push_rule_evaluator.py b/synapse/push/bulk_push_rule_evaluator.py index f943ff640f..5b1f9a1c2d 100644 --- a/synapse/push/bulk_push_rule_evaluator.py +++ b/synapse/push/bulk_push_rule_evaluator.py @@ -19,60 +19,81 @@ from twisted.internet import defer from .push_rule_evaluator import PushRuleEvaluatorForEvent -from synapse.api.constants import EventTypes from synapse.visibility import filter_events_for_clients_context +from synapse.api.constants import EventTypes, Membership +from synapse.util.caches.descriptors import cached +from synapse.util.async import Linearizer logger = logging.getLogger(__name__) -@defer.inlineCallbacks -def evaluator_for_event(event, hs, store, context): - rules_by_user = yield store.bulk_get_push_rules_for_room( - event, context - ) - - # if this event is an invite event, we may need to run rules for the user - # who's been invited, otherwise they won't get told they've been invited - if event.type == 'm.room.member' and event.content['membership'] == 'invite': - invited_user = event.state_key - if invited_user and hs.is_mine_id(invited_user): - has_pusher = yield store.user_has_pusher(invited_user) - if has_pusher: - rules_by_user = dict(rules_by_user) - rules_by_user[invited_user] = yield store.get_push_rules_for_user( - invited_user - ) - - defer.returnValue(BulkPushRuleEvaluator( - event.room_id, rules_by_user, store - )) +rules_by_room = {} class BulkPushRuleEvaluator: + """Calculates the outcome of push rules for an event for all users in the + room at once. """ - Runs push rules for all users in a room. - This is faster than running PushRuleEvaluator for each user because it - fetches all the rules for all the users in one (batched) db query - rather than doing multiple queries per-user. It currently uses - the same logic to run the actual rules, but could be optimised further - (see https://matrix.org/jira/browse/SYN-562) - """ - def __init__(self, room_id, rules_by_user, store): - self.room_id = room_id - self.rules_by_user = rules_by_user - self.store = store + + def __init__(self, hs): + self.hs = hs + self.store = hs.get_datastore() + + @defer.inlineCallbacks + def _get_rules_for_event(self, event, context): + """This gets the rules for all users in the room at the time of the event, + as well as the push rules for the invitee if the event is an invite. + + Returns: + dict of user_id -> push_rules + """ + room_id = event.room_id + rules_for_room = self._get_rules_for_room(room_id) + + rules_by_user = yield rules_for_room.get_rules(context) + + # if this event is an invite event, we may need to run rules for the user + # who's been invited, otherwise they won't get told they've been invited + if event.type == 'm.room.member' and event.content['membership'] == 'invite': + invited = event.state_key + if invited and self.hs.is_mine_id(invited): + has_pusher = yield self.store.user_has_pusher(invited) + if has_pusher: + rules_by_user = dict(rules_by_user) + rules_by_user[invited] = yield self.store.get_push_rules_for_user( + invited + ) + + defer.returnValue(rules_by_user) + + @cached(max_entries=10000) + def _get_rules_for_room(self, room_id): + """Get the current RulesForRoom object for the given room id + + Returns: + RulesForRoom + """ + # It's important that RulesForRoom gets added to self._get_rules_for_room.cache + # before any lookup methods get called on it as otherwise there may be + # a race if invalidate_all gets called (which assumes its in the cache) + return RulesForRoom(self.hs, room_id, self._get_rules_for_room.cache) @defer.inlineCallbacks def action_for_event_by_user(self, event, context): + """Given an event and context, evaluate the push rules and return + the results + + Returns: + dict of user_id -> action + """ + rules_by_user = yield self._get_rules_for_event(event, context) actions_by_user = {} # None of these users can be peeking since this list of users comes # from the set of users in the room, so we know for sure they're all # actually in the room. - user_tuples = [ - (u, False) for u in self.rules_by_user.keys() - ] + user_tuples = [(u, False) for u in rules_by_user] filtered_by_user = yield filter_events_for_clients_context( self.store, user_tuples, [event], {event.event_id: context} @@ -86,7 +107,7 @@ class BulkPushRuleEvaluator: condition_cache = {} - for uid, rules in self.rules_by_user.items(): + for uid, rules in rules_by_user.iteritems(): display_name = None profile_info = room_members.get(uid) if profile_info: @@ -138,3 +159,190 @@ def _condition_checker(evaluator, conditions, uid, display_name, cache): return False return True + + +class RulesForRoom(object): + """Caches push rules for users in a room. + + This efficiently handles users joining/leaving the room by not invalidating + the entire cache for the room. + """ + + def __init__(self, hs, room_id, rules_for_room_cache): + """ + Args: + hs (HomeServer) + room_id (str) + rules_for_room_cache(Cache): The cache object that caches these + RoomsForUser objects. + """ + self.room_id = room_id + self.is_mine_id = hs.is_mine_id + self.store = hs.get_datastore() + + self.linearizer = Linearizer(name="rules_for_room") + + self.member_map = {} # event_id -> (user_id, state) + self.rules_by_user = {} # user_id -> rules + + # The last state group we updated the caches for. If the state_group of + # a new event comes along, we know that we can just return the cached + # result. + # On invalidation of the rules themselves (if the user changes them), + # we invalidate everything and set state_group to `object()` + self.state_group = object() + + # A sequence number to keep track of when we're allowed to update the + # cache. We bump the sequence number when we invalidate the cache. If + # the sequence number changes while we're calculating stuff we should + # not update the cache with it. + self.sequence = 0 + + # We need to be clever on the invalidating caches callbacks, as + # otherwise the invalidation callback holds a reference to the object, + # potentially causing it to leak. + # To get around this we pass a function that on invalidations looks ups + # the RoomsForUser entry in the cache, rather than keeping a reference + # to self around in the callback. + def invalidate_all_cb(): + rules = rules_for_room_cache.get(room_id, update_metrics=False) + if rules: + rules.invalidate_all() + + self.invalidate_all_cb = invalidate_all_cb + + @defer.inlineCallbacks + def get_rules(self, context): + """Given an event context return the rules for all users who are + currently in the room. + """ + state_group = context.state_group + + with (yield self.linearizer.queue(())): + if state_group and self.state_group == state_group: + defer.returnValue(self.rules_by_user) + + ret_rules_by_user = {} + missing_member_event_ids = {} + if state_group and self.state_group == context.prev_group: + # If we have a simple delta then we can reuse most of the previous + # results. + ret_rules_by_user = self.rules_by_user + current_state_ids = context.delta_ids + else: + current_state_ids = context.current_state_ids + + # Loop through to see which member events we've seen and have rules + # for and which we need to fetch + for key, event_id in current_state_ids.iteritems(): + if key[0] != EventTypes.Member: + continue + + res = self.member_map.get(event_id, None) + if res: + user_id, state = res + if state == Membership.JOIN: + rules = self.rules_by_user.get(user_id, None) + if rules: + ret_rules_by_user[user_id] = rules + continue + + user_id = key[1] + if not self.is_mine_id(user_id): + continue + + if self.store.get_if_app_services_interested_in_user( + user_id, exclusive=True + ): + continue + + # If a user has left a room we remove their push rule. If they + # joined then we readd it later in _update_rules_with_member_event_ids + ret_rules_by_user.pop(user_id, None) + missing_member_event_ids[user_id] = event_id + + if missing_member_event_ids: + # If we have some memebr events we haven't seen, look them up + # and fetch push rules for them if appropriate. + yield self._update_rules_with_member_event_ids( + ret_rules_by_user, missing_member_event_ids, state_group + ) + + defer.returnValue(ret_rules_by_user) + + @defer.inlineCallbacks + def _update_rules_with_member_event_ids(self, ret_rules_by_user, member_event_ids, + state_group): + """Update the partially filled rules_by_user dict by fetching rules for + any newly joined users in the `member_event_ids` list. + + Args: + ret_rules_by_user (dict): Partiallly filled dict of push rules. Gets + updated with any new rules. + member_event_ids (list): List of event ids for membership events that + have happened since the last time we filled rules_by_user + state_group: The state group we are currently computing push rules + for. Used when updating the cache. + """ + sequence = self.sequence + + rows = yield self.store._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=member_event_ids.values(), + retcols=('user_id', 'membership', 'event_id'), + keyvalues={}, + batch_size=500, + desc="_get_rules_for_member_event_ids", + ) + + members = { + row["event_id"]: (row["user_id"], row["membership"]) + for row in rows + } + + interested_in_user_ids = set(user_id for user_id, _ in members.itervalues()) + + if_users_with_pushers = yield self.store.get_if_users_have_pushers( + interested_in_user_ids, + on_invalidate=self.invalidate_all_cb, + ) + + user_ids = set( + uid for uid, have_pusher in if_users_with_pushers.iteritems() if have_pusher + ) + + users_with_receipts = yield self.store.get_users_with_read_receipts_in_room( + self.room_id, on_invalidate=self.invalidate_all_cb, + ) + + # any users with pushers must be ours: they have pushers + for uid in users_with_receipts: + if uid in interested_in_user_ids: + user_ids.add(uid) + + rules_by_user = yield self.store.bulk_get_push_rules( + user_ids, on_invalidate=self.invalidate_all_cb, + ) + + ret_rules_by_user.update( + item for item in rules_by_user.iteritems() if item[0] is not None + ) + + self.update_cache(sequence, members, ret_rules_by_user, state_group) + + def invalidate_all(self): + # Note: Don't hand this function directly to an invalidation callback + # as it keeps a reference to self and will stop this instance from being + # GC'd if it gets dropped from the rules_to_user cache. Instead use + # `self.invalidate_all_cb` + self.sequence += 1 + self.state_group = object() + self.member_map = {} + self.rules_by_user = {} + + def update_cache(self, sequence, members, rules_by_user, state_group): + if sequence == self.sequence: + self.member_map.update(members) + self.rules_by_user = rules_by_user + self.state_group = state_group diff --git a/synapse/storage/appservice.py b/synapse/storage/appservice.py index 514570561f..0e9e8d3452 100644 --- a/synapse/storage/appservice.py +++ b/synapse/storage/appservice.py @@ -39,12 +39,15 @@ class ApplicationServiceStore(SQLBaseStore): def get_app_services(self): return self.services_cache - def get_if_app_services_interested_in_user(self, user_id): + def get_if_app_services_interested_in_user(self, user_id, exclusive=False): """Check if the user is one associated with an app service """ for service in self.services_cache: if service.is_interested_in_user(user_id): - return True + if exclusive: + return service.is_exclusive_user(user_id) + else: + return True return False def get_app_service_by_user_id(self, user_id): diff --git a/synapse/storage/push_rule.py b/synapse/storage/push_rule.py index 0a819d32c5..65bad3fad6 100644 --- a/synapse/storage/push_rule.py +++ b/synapse/storage/push_rule.py @@ -163,7 +163,7 @@ class PushRuleStore(SQLBaseStore): local_users_in_room = set( u for u in users_in_room if self.hs.is_mine_id(u) - and not self.get_if_app_services_interested_in_user(u) + and not self.get_if_app_services_interested_in_user(u, exclusive=True) ) # users in the room who have pushers need to get push rules run because diff --git a/synapse/storage/state.py b/synapse/storage/state.py index 85acf2ad1e..a7c3d401d4 100644 --- a/synapse/storage/state.py +++ b/synapse/storage/state.py @@ -563,20 +563,22 @@ class StateStore(SQLBaseStore): where a `state_key` of `None` matches all state_keys for the `type`. """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, known_absent, state_dict_ids = self._state_group_cache.get(group) type_to_key = {} missing_types = set() + for typ, state_key in types: + key = (typ, state_key) if state_key is None: type_to_key[typ] = None - missing_types.add((typ, state_key)) + missing_types.add(key) else: if type_to_key.get(typ, object()) is not None: type_to_key.setdefault(typ, set()).add(state_key) - if (typ, state_key) not in state_dict_ids: - missing_types.add((typ, state_key)) + if key not in state_dict_ids and key not in known_absent: + missing_types.add(key) sentinel = object() @@ -590,7 +592,7 @@ class StateStore(SQLBaseStore): return True return False - got_all = not (missing_types or types is None) + got_all = is_all or not missing_types return { k: v for k, v in state_dict_ids.iteritems() @@ -607,7 +609,7 @@ class StateStore(SQLBaseStore): Args: group: The state group to lookup """ - is_all, state_dict_ids = self._state_group_cache.get(group) + is_all, _, state_dict_ids = self._state_group_cache.get(group) return state_dict_ids, is_all @@ -624,7 +626,7 @@ class StateStore(SQLBaseStore): missing_groups = [] if types is not None: for group in set(groups): - state_dict_ids, missing_types, got_all = self._get_some_state_from_cache( + state_dict_ids, _, got_all = self._get_some_state_from_cache( group, types ) results[group] = state_dict_ids @@ -653,19 +655,7 @@ class StateStore(SQLBaseStore): # Now we want to update the cache with all the things we fetched # from the database. for group, group_state_dict in group_to_state_dict.iteritems(): - if types: - # We delibrately put key -> None mappings into the cache to - # cache absence of the key, on the assumption that if we've - # explicitly asked for some types then we will probably ask - # for them again. - state_dict = { - (intern_string(etype), intern_string(state_key)): None - for (etype, state_key) in types - } - state_dict.update(results[group]) - results[group] = state_dict - else: - state_dict = results[group] + state_dict = results[group] state_dict.update( ((intern_string(k[0]), intern_string(k[1])), to_ascii(v)) @@ -677,17 +667,9 @@ class StateStore(SQLBaseStore): key=group, value=state_dict, full=(types is None), + known_absent=types, ) - # Remove all the entries with None values. The None values were just - # used for bookkeeping in the cache. - for group, state_dict in results.iteritems(): - results[group] = { - key: event_id - for key, event_id in state_dict.iteritems() - if event_id - } - defer.returnValue(results) def get_next_state_group(self): diff --git a/synapse/util/caches/dictionary_cache.py b/synapse/util/caches/dictionary_cache.py index cb6933c61c..d4105822b3 100644 --- a/synapse/util/caches/dictionary_cache.py +++ b/synapse/util/caches/dictionary_cache.py @@ -23,7 +23,17 @@ import logging logger = logging.getLogger(__name__) -class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "value"))): +class DictionaryEntry(namedtuple("DictionaryEntry", ("full", "known_absent", "value"))): + """Returned when getting an entry from the cache + + Attributes: + full (bool): Whether the cache has the full or dict or just some keys. + If not full then not all requested keys will necessarily be present + in `value` + known_absent (set): Keys that were looked up in the dict and were not + there. + value (dict): The full or partial dict value + """ def __len__(self): return len(self.value) @@ -58,21 +68,31 @@ class DictionaryCache(object): ) def get(self, key, dict_keys=None): + """Fetch an entry out of the cache + + Args: + key + dict_key(list): If given a set of keys then return only those keys + that exist in the cache. + + Returns: + DictionaryEntry + """ entry = self.cache.get(key, self.sentinel) if entry is not self.sentinel: self.metrics.inc_hits() if dict_keys is None: - return DictionaryEntry(entry.full, dict(entry.value)) + return DictionaryEntry(entry.full, entry.known_absent, dict(entry.value)) else: - return DictionaryEntry(entry.full, { + return DictionaryEntry(entry.full, entry.known_absent, { k: entry.value[k] for k in dict_keys if k in entry.value }) self.metrics.inc_misses() - return DictionaryEntry(False, {}) + return DictionaryEntry(False, set(), {}) def invalidate(self, key): self.check_thread() @@ -87,19 +107,34 @@ class DictionaryCache(object): self.sequence += 1 self.cache.clear() - def update(self, sequence, key, value, full=False): + def update(self, sequence, key, value, full=False, known_absent=None): + """Updates the entry in the cache + + Args: + sequence + key + value (dict): The value to update the cache with. + full (bool): Whether the given value is the full dict, or just a + partial subset there of. If not full then any existing entries + for the key will be updated. + known_absent (set): Set of keys that we know don't exist in the full + dict. + """ self.check_thread() if self.sequence == sequence: # Only update the cache if the caches sequence number matches the # number that the cache had before the SELECT was started (SYN-369) + if known_absent is None: + known_absent = set() if full: - self._insert(key, value) + self._insert(key, value, known_absent) else: - self._update_or_insert(key, value) + self._update_or_insert(key, value, known_absent) - def _update_or_insert(self, key, value): - entry = self.cache.setdefault(key, DictionaryEntry(False, {})) + def _update_or_insert(self, key, value, known_absent): + entry = self.cache.setdefault(key, DictionaryEntry(False, set(), {})) entry.value.update(value) + entry.known_absent.update(known_absent) - def _insert(self, key, value): - self.cache[key] = DictionaryEntry(True, value) + def _insert(self, key, value, known_absent): + self.cache[key] = DictionaryEntry(True, known_absent, value) diff --git a/tests/util/test_dict_cache.py b/tests/util/test_dict_cache.py index 272b71034a..bc92f85fa6 100644 --- a/tests/util/test_dict_cache.py +++ b/tests/util/test_dict_cache.py @@ -28,7 +28,7 @@ class DictCacheTestCase(unittest.TestCase): key = "test_simple_cache_hit_full" v = self.cache.get(key) - self.assertEqual((False, {}), v) + self.assertEqual((False, set(), {}), v) seq = self.cache.sequence test_value = {"test": "test_simple_cache_hit_full"} |