summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/sync.py4
-rw-r--r--synapse/state/__init__.py13
-rw-r--r--synapse/storage/databases/main/roommember.py67
-rw-r--r--synapse/util/caches/descriptors.py26
4 files changed, 56 insertions, 54 deletions
diff --git a/synapse/handlers/sync.py b/synapse/handlers/sync.py
index b4d3f3958c..2d95b1fa24 100644
--- a/synapse/handlers/sync.py
+++ b/synapse/handlers/sync.py
@@ -2421,10 +2421,10 @@ class SyncHandler:
                     joined_room.room_id, joined_room.event_pos.stream
                 )
             )
-            users_in_room = await self.state.get_current_users_in_room(
+            user_ids_in_room = await self.state.get_current_user_ids_in_room(
                 joined_room.room_id, extrems
             )
-            if user_id in users_in_room:
+            if user_id in user_ids_in_room:
                 joined_room_ids.add(joined_room.room_id)
 
         return frozenset(joined_room_ids)
diff --git a/synapse/state/__init__.py b/synapse/state/__init__.py
index c355e4f98a..3047e1b1ad 100644
--- a/synapse/state/__init__.py
+++ b/synapse/state/__init__.py
@@ -44,7 +44,6 @@ from synapse.logging.context import ContextResourceUsage
 from synapse.replication.http.state import ReplicationUpdateCurrentStateRestServlet
 from synapse.state import v1, v2
 from synapse.storage.databases.main.events_worker import EventRedactBehaviour
-from synapse.storage.roommember import ProfileInfo
 from synapse.storage.state import StateFilter
 from synapse.types import StateMap
 from synapse.util.async_helpers import Linearizer
@@ -210,11 +209,11 @@ class StateHandler:
         ret = await self.resolve_state_groups_for_events(room_id, event_ids)
         return await ret.get_state(self._state_storage_controller, state_filter)
 
-    async def get_current_users_in_room(
+    async def get_current_user_ids_in_room(
         self, room_id: str, latest_event_ids: List[str]
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         """
-        Get the users who are currently in a room.
+        Get the users IDs who are currently in a room.
 
         Note: This is much slower than using the equivalent method
         `DataStore.get_users_in_room` or `DataStore.get_users_in_room_with_profiles`,
@@ -225,15 +224,15 @@ class StateHandler:
             room_id: The ID of the room.
             latest_event_ids: Precomputed list of latest event IDs. Will be computed if None.
         Returns:
-            Dictionary of user IDs to their profileinfo.
+            Set of user IDs in the room.
         """
 
         assert latest_event_ids is not None
 
-        logger.debug("calling resolve_state_groups from get_current_users_in_room")
+        logger.debug("calling resolve_state_groups from get_current_user_ids_in_room")
         entry = await self.resolve_state_groups_for_events(room_id, latest_event_ids)
         state = await entry.get_state(self._state_storage_controller, StateFilter.all())
-        return await self.store.get_joined_users_from_state(room_id, state, entry)
+        return await self.store.get_joined_user_ids_from_state(room_id, state, entry)
 
     async def get_hosts_in_room_at_events(
         self, room_id: str, event_ids: Collection[str]
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index 827c1f1efd..0eb024a809 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -835,9 +835,9 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return shared_room_ids or frozenset()
 
-    async def get_joined_users_from_state(
+    async def get_joined_user_ids_from_state(
         self, room_id: str, state: StateMap[str], state_entry: "_StateCacheEntry"
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         state_group: Union[object, int] = state_entry.state_group
         if not state_group:
             # If state_group is None it means it has yet to be assigned a
@@ -848,25 +848,25 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         assert state_group is not None
         with Measure(self._clock, "get_joined_users_from_state"):
-            return await self._get_joined_users_from_context(
+            return await self._get_joined_user_ids_from_context(
                 room_id, state_group, state, context=state_entry
             )
 
     @cached(num_args=2, iterable=True, max_entries=100000)
-    async def _get_joined_users_from_context(
+    async def _get_joined_user_ids_from_context(
         self,
         room_id: str,
         state_group: Union[object, int],
         current_state_ids: StateMap[str],
         event: Optional[EventBase] = None,
         context: Optional["_StateCacheEntry"] = None,
-    ) -> Dict[str, ProfileInfo]:
+    ) -> Set[str]:
         # We don't use `state_group`, it's there so that we can cache based
         # on it. However, it's important that it's never None, since two current_states
         # with a state_group of None are likely to be different.
         assert state_group is not None
 
-        users_in_room = {}
+        users_in_room = set()
         member_event_ids = [
             e_id
             for key, e_id in current_state_ids.items()
@@ -879,11 +879,11 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             # If we do then we can reuse that result and simply update it with
             # any membership changes in `delta_ids`
             if context.prev_group and context.delta_ids:
-                prev_res = self._get_joined_users_from_context.cache.get_immediate(
+                prev_res = self._get_joined_user_ids_from_context.cache.get_immediate(
                     (room_id, context.prev_group), None
                 )
-                if prev_res and isinstance(prev_res, dict):
-                    users_in_room = dict(prev_res)
+                if prev_res and isinstance(prev_res, set):
+                    users_in_room = prev_res
                     member_event_ids = [
                         e_id
                         for key, e_id in context.delta_ids.items()
@@ -891,7 +891,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                     ]
                     for etype, state_key in context.delta_ids:
                         if etype == EventTypes.Member:
-                            users_in_room.pop(state_key, None)
+                            users_in_room.discard(state_key)
 
         # We check if we have any of the member event ids in the event cache
         # before we ask the DB
@@ -908,42 +908,41 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             ev_entry = event_map.get(event_id)
             if ev_entry and not ev_entry.event.rejected_reason:
                 if ev_entry.event.membership == Membership.JOIN:
-                    users_in_room[ev_entry.event.state_key] = ProfileInfo(
-                        display_name=ev_entry.event.content.get("displayname", None),
-                        avatar_url=ev_entry.event.content.get("avatar_url", None),
-                    )
+                    users_in_room.add(ev_entry.event.state_key)
             else:
                 missing_member_event_ids.append(event_id)
 
         if missing_member_event_ids:
-            event_to_memberships = await self._get_joined_profiles_from_event_ids(
+            event_to_memberships = await self._get_user_ids_from_membership_event_ids(
                 missing_member_event_ids
             )
-            users_in_room.update(row for row in event_to_memberships.values() if row)
+            users_in_room.update(event_to_memberships.values())
 
         if event is not None and event.type == EventTypes.Member:
             if event.membership == Membership.JOIN:
                 if event.event_id in member_event_ids:
-                    users_in_room[event.state_key] = ProfileInfo(
-                        display_name=event.content.get("displayname", None),
-                        avatar_url=event.content.get("avatar_url", None),
-                    )
+                    users_in_room.add(event.state_key)
 
         return users_in_room
 
-    @cached(max_entries=10000)
-    def _get_joined_profile_from_event_id(
+    @cached(
+        max_entries=10000,
+        # This name matches the old function that has been replaced - the cache name
+        # is kept here to maintain backwards compatibility.
+        name="_get_joined_profile_from_event_id",
+    )
+    def _get_user_id_from_membership_event_id(
         self, event_id: str
     ) -> Optional[Tuple[str, ProfileInfo]]:
         raise NotImplementedError()
 
     @cachedList(
-        cached_method_name="_get_joined_profile_from_event_id",
+        cached_method_name="_get_user_id_from_membership_event_id",
         list_name="event_ids",
     )
-    async def _get_joined_profiles_from_event_ids(
+    async def _get_user_ids_from_membership_event_ids(
         self, event_ids: Iterable[str]
-    ) -> Dict[str, Optional[Tuple[str, ProfileInfo]]]:
+    ) -> Dict[str, str]:
         """For given set of member event_ids check if they point to a join
         event and if so return the associated user and profile info.
 
@@ -958,21 +957,13 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             table="room_memberships",
             column="event_id",
             iterable=event_ids,
-            retcols=("user_id", "display_name", "avatar_url", "event_id"),
+            retcols=("user_id", "event_id"),
             keyvalues={"membership": Membership.JOIN},
             batch_size=1000,
-            desc="_get_joined_profiles_from_event_ids",
+            desc="_get_user_ids_from_membership_event_ids",
         )
 
-        return {
-            row["event_id"]: (
-                row["user_id"],
-                ProfileInfo(
-                    avatar_url=row["avatar_url"], display_name=row["display_name"]
-                ),
-            )
-            for row in rows
-        }
+        return {row["event_id"]: row["user_id"] for row in rows}
 
     @cached(max_entries=10000)
     async def is_host_joined(self, room_id: str, host: str) -> bool:
@@ -1131,12 +1122,12 @@ class RoomMemberWorkerStore(EventsWorkerStore):
             else:
                 # The cache doesn't match the state group or prev state group,
                 # so we calculate the result from first principles.
-                joined_users = await self.get_joined_users_from_state(
+                joined_user_ids = await self.get_joined_user_ids_from_state(
                     room_id, state, state_entry
                 )
 
                 cache.hosts_to_joined_users = {}
-                for user_id in joined_users:
+                for user_id in joined_user_ids:
                     host = intern_string(get_domain_from_id(user_id))
                     cache.hosts_to_joined_users.setdefault(host, set()).add(user_id)
 
diff --git a/synapse/util/caches/descriptors.py b/synapse/util/caches/descriptors.py
index 867f315b2a..9d4bc89edb 100644
--- a/synapse/util/caches/descriptors.py
+++ b/synapse/util/caches/descriptors.py
@@ -73,8 +73,10 @@ class _CacheDescriptorBase:
         num_args: Optional[int],
         uncached_args: Optional[Collection[str]] = None,
         cache_context: bool = False,
+        name: Optional[str] = None,
     ):
         self.orig = orig
+        self.name = name or orig.__name__
 
         arg_spec = inspect.getfullargspec(orig)
         all_args = arg_spec.args
@@ -211,7 +213,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: LruCache[CacheKey, Any] = LruCache(
-            cache_name=self.orig.__name__,
+            cache_name=self.name,
             max_size=self.max_entries,
         )
 
@@ -241,7 +243,7 @@ class LruCacheDescriptor(_CacheDescriptorBase):
 
         wrapped = cast(_CachedFunction, _wrapped)
         wrapped.cache = cache
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -301,12 +303,14 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         cache_context: bool = False,
         iterable: bool = False,
         prune_unread_entries: bool = True,
+        name: Optional[str] = None,
     ):
         super().__init__(
             orig,
             num_args=num_args,
             uncached_args=uncached_args,
             cache_context=cache_context,
+            name=name,
         )
 
         if tree and self.num_args < 2:
@@ -321,7 +325,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
 
     def __get__(self, obj: Optional[Any], owner: Optional[Type]) -> Callable[..., Any]:
         cache: DeferredCache[CacheKey, Any] = DeferredCache(
-            name=self.orig.__name__,
+            name=self.name,
             max_entries=self.max_entries,
             tree=self.tree,
             iterable=self.iterable,
@@ -372,7 +376,7 @@ class DeferredCacheDescriptor(_CacheDescriptorBase):
         wrapped.cache = cache
         wrapped.num_args = self.num_args
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -393,6 +397,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
         cached_method_name: str,
         list_name: str,
         num_args: Optional[int] = None,
+        name: Optional[str] = None,
     ):
         """
         Args:
@@ -403,7 +408,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
                 but including list_name) to use as cache keys. Defaults to all
                 named args of the function.
         """
-        super().__init__(orig, num_args=num_args, uncached_args=None)
+        super().__init__(orig, num_args=num_args, uncached_args=None, name=name)
 
         self.list_name = list_name
 
@@ -525,7 +530,7 @@ class DeferredCacheListDescriptor(_CacheDescriptorBase):
             else:
                 return defer.succeed(results)
 
-        obj.__dict__[self.orig.__name__] = wrapped
+        obj.__dict__[self.name] = wrapped
 
         return wrapped
 
@@ -577,6 +582,7 @@ def cached(
     cache_context: bool = False,
     iterable: bool = False,
     prune_unread_entries: bool = True,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     func = lambda orig: DeferredCacheDescriptor(
         orig,
@@ -587,13 +593,18 @@ def cached(
         cache_context=cache_context,
         iterable=iterable,
         prune_unread_entries=prune_unread_entries,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)
 
 
 def cachedList(
-    *, cached_method_name: str, list_name: str, num_args: Optional[int] = None
+    *,
+    cached_method_name: str,
+    list_name: str,
+    num_args: Optional[int] = None,
+    name: Optional[str] = None,
 ) -> Callable[[F], _CachedFunction[F]]:
     """Creates a descriptor that wraps a function in a `DeferredCacheListDescriptor`.
 
@@ -628,6 +639,7 @@ def cachedList(
         cached_method_name=cached_method_name,
         list_name=list_name,
         num_args=num_args,
+        name=name,
     )
 
     return cast(Callable[[F], _CachedFunction[F]], func)