summary refs log tree commit diff
path: root/synapse/storage/roommember.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/roommember.py')
-rw-r--r--synapse/storage/roommember.py85
1 files changed, 60 insertions, 25 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py

index 4df8ebdacd..59a89fad60 100644 --- a/synapse/storage/roommember.py +++ b/synapse/storage/roommember.py
@@ -32,7 +32,8 @@ from synapse.storage.events_worker import EventsWorkerStore from synapse.types import get_domain_from_id from synapse.util.async_helpers import Linearizer from synapse.util.caches import intern_string -from synapse.util.caches.descriptors import cached, cachedInlineCallbacks +from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList +from synapse.util.metrics import Measure from synapse.util.stringutils import to_ascii logger = logging.getLogger(__name__) @@ -483,6 +484,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): ) return result + @defer.inlineCallbacks def get_joined_users_from_state(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: @@ -492,9 +494,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._get_joined_users_from_context( - room_id, state_group, state_entry.state, context=state_entry - ) + with Measure(self._clock, "get_joined_users_from_state"): + return ( + yield self._get_joined_users_from_context( + room_id, state_group, state_entry.state, context=state_entry + ) + ) @cachedInlineCallbacks( num_args=2, cache_context=True, iterable=True, max_entries=100000 @@ -567,25 +572,10 @@ class RoomMemberWorkerStore(EventsWorkerStore): missing_member_event_ids.append(event_id) if missing_member_event_ids: - rows = yield self._simple_select_many_batch( - table="room_memberships", - column="event_id", - iterable=missing_member_event_ids, - retcols=("user_id", "display_name", "avatar_url"), - keyvalues={"membership": Membership.JOIN}, - batch_size=500, - desc="_get_joined_users_from_context", - ) - - users_in_room.update( - { - to_ascii(row["user_id"]): ProfileInfo( - avatar_url=to_ascii(row["avatar_url"]), - display_name=to_ascii(row["display_name"]), - ) - for row in rows - } + event_to_memberships = yield self._get_joined_profiles_from_event_ids( + missing_member_event_ids ) + users_in_room.update((row for row in event_to_memberships.values() if row)) if event is not None and event.type == EventTypes.Member: if event.membership == Membership.JOIN: @@ -597,6 +587,47 @@ class RoomMemberWorkerStore(EventsWorkerStore): return users_in_room + @cached(max_entries=10000) + def _get_joined_profile_from_event_id(self, event_id): + raise NotImplementedError() + + @cachedList( + cached_method_name="_get_joined_profile_from_event_id", + list_name="event_ids", + inlineCallbacks=True, + ) + def _get_joined_profiles_from_event_ids(self, event_ids): + """For given set of member event_ids check if they point to a join + event and if so return the associated user and profile info. + + Args: + event_ids (Iterable[str]): The member event IDs to lookup + + Returns: + Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID + to `user_id` and ProfileInfo (or None if not join event). + """ + + rows = yield self._simple_select_many_batch( + table="room_memberships", + column="event_id", + iterable=event_ids, + retcols=("user_id", "display_name", "avatar_url", "event_id"), + keyvalues={"membership": Membership.JOIN}, + batch_size=500, + desc="_get_membership_from_event_ids", + ) + + return { + row["event_id"]: ( + row["user_id"], + ProfileInfo( + avatar_url=row["avatar_url"], display_name=row["display_name"] + ), + ) + for row in rows + } + @cachedInlineCallbacks(max_entries=10000) def is_host_joined(self, room_id, host): if "%" in host or "_" in host: @@ -669,6 +700,7 @@ class RoomMemberWorkerStore(EventsWorkerStore): return True + @defer.inlineCallbacks def get_joined_hosts(self, room_id, state_entry): state_group = state_entry.state_group if not state_group: @@ -678,9 +710,12 @@ class RoomMemberWorkerStore(EventsWorkerStore): # To do this we set the state_group to a new object as object() != object() state_group = object() - return self._get_joined_hosts( - room_id, state_group, state_entry.state, state_entry=state_entry - ) + with Measure(self._clock, "get_joined_hosts"): + return ( + yield self._get_joined_hosts( + room_id, state_group, state_entry.state, state_entry=state_entry + ) + ) @cachedInlineCallbacks(num_args=2, max_entries=10000, iterable=True) # @defer.inlineCallbacks