diff options
author | Erik Johnston <erik@matrix.org> | 2017-05-02 10:36:35 +0100 |
---|---|---|
committer | Erik Johnston <erik@matrix.org> | 2017-05-02 10:36:35 +0100 |
commit | 7166854f4169999fee0cd40a5ed389cc684b6dc8 (patch) | |
tree | 7443dd5e6dc7ef0d83a1a1109d3ff57ff8379acd /synapse/storage | |
parent | Merge pull request #2080 from matrix-org/erikj/filter_speed (diff) | |
download | synapse-7166854f4169999fee0cd40a5ed389cc684b6dc8.tar.xz |
Add cache for get_current_hosts_in_room
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/roommember.py | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py index 7ad2198d96..1c0fa8a680 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py @@ -482,6 +482,44 @@ class RoomMemberStore(SQLBaseStore): defer.returnValue(False) + def get_joined_hosts(self, room_id, state_group, state_ids): + if not state_group: + # If state_group is None it means it has yet to be assigned a + # state group, i.e. we need to make sure that calls with a state_group + # of None don't hit previous cached calls with a None state_group. + # To do this we set the state_group to a new object as object() != object() + state_group = object() + + return self._get_joined_hosts( + room_id, state_group, state_ids + ) + + @cachedInlineCallbacks(num_args=3) + def _get_joined_hosts(self, room_id, state_group, current_state_ids): + # We don't use `state_group`, its there so that we can cache based + # on it. However, its important that its never None, since two current_state's + # with a state_group of None are likely to be different. + # See bulk_get_push_rules_for_room for how we work around this. + assert state_group is not None + + joined_hosts = set() + for (etype, state_key), event_id in current_state_ids.items(): + if etype == EventTypes.Member: + try: + host = get_domain_from_id(state_key) + except: + logger.warn("state_key not user_id: %s", state_key) + continue + + if host in joined_hosts: + continue + + event = yield self.get_event(event_id, allow_none=True) + if event and event.content["membership"] == Membership.JOIN: + joined_hosts.add(host) + + defer.returnValue(joined_hosts) + @defer.inlineCallbacks def _background_add_membership_profile(self, progress, batch_size): target_min_stream_id = progress.get( |