summary refs log tree commit diff
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-10-03 17:11:04 +0100
committerErik Johnston <erik@matrix.org>2019-10-03 17:11:04 +0100
commitc8145af8a9446911bbb52fc0114eebf4eebede4b (patch)
treec7e78d229acd794196d39ae74f3728cf1169a44e
parentMerge branch 'master' into develop (diff)
downloadsynapse-c8145af8a9446911bbb52fc0114eebf4eebede4b.tar.xz
Cache room membership lookups in _get_joined_users_from_context
Diffstat (limited to '')
-rw-r--r--synapse/storage/roommember.py64
1 files changed, 45 insertions, 19 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 4df8ebdacd..f11bbd05f8 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -16,6 +16,7 @@
 
 import logging
 from collections import namedtuple
+from typing import Iterable
 
 from six import iteritems, itervalues
 
@@ -32,7 +33,7 @@ from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 from synapse.util.stringutils import to_ascii
 
 logger = logging.getLogger(__name__)
@@ -567,25 +568,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 missing_member_event_ids.append(event_id)
 
         if missing_member_event_ids:
-            rows = yield self._simple_select_many_batch(
-                table="room_memberships",
-                column="event_id",
-                iterable=missing_member_event_ids,
-                retcols=("user_id", "display_name", "avatar_url"),
-                keyvalues={"membership": Membership.JOIN},
-                batch_size=500,
-                desc="_get_joined_users_from_context",
-            )
-
-            users_in_room.update(
-                {
-                    to_ascii(row["user_id"]): ProfileInfo(
-                        avatar_url=to_ascii(row["avatar_url"]),
-                        display_name=to_ascii(row["display_name"]),
-                    )
-                    for row in rows
-                }
+            event_to_memberships = yield self._get_membership_from_event_ids(
+                missing_member_event_ids
             )
+            users_in_room.update((row for row in event_to_memberships.values() if row))
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
@@ -597,6 +583,46 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return users_in_room
 
+    @cached(max_entries=10000)
+    def _get_membership_from_event_id(self, event_id):
+        raise NotADirectoryError()
+
+    @cachedList(
+        cached_method_name="_get_membership_from_event_id",
+        list_name="event_ids",
+        inlineCallbacks=True,
+    )
+    def _get_membership_from_event_ids(self, event_ids: Iterable[str]):
+        """Lookup profile info for set of member event IDs.
+
+        Args:
+            event_ids: The member event IDs to lookup
+
+        Returns:
+            Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+            to `user_id` and ProfileInfo (or None if couldn't find event).
+        """
+
+        rows = yield self._simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=event_ids,
+            retcols=("user_id", "display_name", "avatar_url"),
+            keyvalues={"membership": Membership.JOIN},
+            batch_size=500,
+            desc="_get_membership_from_event_ids",
+        )
+
+        return {
+            row["event_id"]: (
+                row["user_id"],
+                ProfileInfo(
+                    avatar_url=row["avatar_url"], display_name=row["display_name"]
+                ),
+            )
+            for row in rows
+        }
+
     @cachedInlineCallbacks(max_entries=10000)
     def is_host_joined(self, room_id, host):
         if "%" in host or "_" in host: