summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
authorErik Johnston <erik@matrix.org>2019-10-07 13:15:00 +0100
committerGitHub <noreply@github.com>2019-10-07 13:15:00 +0100
commit86f4705866ae61ee72b225b4d74fc77288d410c2 (patch)
tree5a6a73ed79549dd49109ecd16924d428b9afa9c1 /synapse
parentMerge pull request #6147 from matrix-org/babolivier/3pid-invite-revoked (diff)
parentFix bug where we didn't pull out event ID (diff)
downloadsynapse-86f4705866ae61ee72b225b4d74fc77288d410c2.tar.xz
Merge pull request #6159 from matrix-org/erikj/cache_memberships
Cache room membership lookups in _get_joined_users_from_context
Diffstat (limited to 'synapse')
-rw-r--r--synapse/storage/roommember.py64
1 files changed, 45 insertions, 19 deletions
diff --git a/synapse/storage/roommember.py b/synapse/storage/roommember.py
index 1550d827ba..59a89fad60 100644
--- a/synapse/storage/roommember.py
+++ b/synapse/storage/roommember.py
@@ -32,7 +32,7 @@ from synapse.storage.events_worker import EventsWorkerStore
 from synapse.types import get_domain_from_id
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
-from synapse.util.caches.descriptors import cached, cachedInlineCallbacks
+from synapse.util.caches.descriptors import cached, cachedInlineCallbacks, cachedList
 from synapse.util.metrics import Measure
 from synapse.util.stringutils import to_ascii
 
@@ -572,25 +572,10 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 missing_member_event_ids.append(event_id)
 
         if missing_member_event_ids:
-            rows = yield self._simple_select_many_batch(
-                table="room_memberships",
-                column="event_id",
-                iterable=missing_member_event_ids,
-                retcols=("user_id", "display_name", "avatar_url"),
-                keyvalues={"membership": Membership.JOIN},
-                batch_size=500,
-                desc="_get_joined_users_from_context",
-            )
-
-            users_in_room.update(
-                {
-                    to_ascii(row["user_id"]): ProfileInfo(
-                        avatar_url=to_ascii(row["avatar_url"]),
-                        display_name=to_ascii(row["display_name"]),
-                    )
-                    for row in rows
-                }
+            event_to_memberships = yield self._get_joined_profiles_from_event_ids(
+                missing_member_event_ids
             )
+            users_in_room.update((row for row in event_to_memberships.values() if row))
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
@@ -602,6 +587,47 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return users_in_room
 
+    @cached(max_entries=10000)
+    def _get_joined_profile_from_event_id(self, event_id):
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="_get_joined_profile_from_event_id",
+        list_name="event_ids",
+        inlineCallbacks=True,
+    )
+    def _get_joined_profiles_from_event_ids(self, event_ids):
+        """For given set of member event_ids check if they point to a join
+        event and if so return the associated user and profile info.
+
+        Args:
+            event_ids (Iterable[str]): The member event IDs to lookup
+
+        Returns:
+            Deferred[dict[str, Tuple[str, ProfileInfo]|None]]: Map from event ID
+            to `user_id` and ProfileInfo (or None if not join event).
+        """
+
+        rows = yield self._simple_select_many_batch(
+            table="room_memberships",
+            column="event_id",
+            iterable=event_ids,
+            retcols=("user_id", "display_name", "avatar_url", "event_id"),
+            keyvalues={"membership": Membership.JOIN},
+            batch_size=500,
+            desc="_get_membership_from_event_ids",
+        )
+
+        return {
+            row["event_id"]: (
+                row["user_id"],
+                ProfileInfo(
+                    avatar_url=row["avatar_url"], display_name=row["display_name"]
+                ),
+            )
+            for row in rows
+        }
+
     @cachedInlineCallbacks(max_entries=10000)
     def is_host_joined(self, room_id, host):
         if "%" in host or "_" in host: