summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/handlers/presence.py112
-rw-r--r--synapse/storage/_base.py4
-rw-r--r--synapse/storage/databases/main/roommember.py83
3 files changed, 108 insertions, 91 deletions
diff --git a/synapse/handlers/presence.py b/synapse/handlers/presence.py
index 895ea63ed3..741504ba9f 100644
--- a/synapse/handlers/presence.py
+++ b/synapse/handlers/presence.py
@@ -34,7 +34,6 @@ from typing import (
     Callable,
     Collection,
     Dict,
-    FrozenSet,
     Generator,
     Iterable,
     List,
@@ -42,7 +41,6 @@ from typing import (
     Set,
     Tuple,
     Type,
-    Union,
 )
 
 from prometheus_client import Counter
@@ -68,7 +66,6 @@ from synapse.storage.databases.main import DataStore
 from synapse.streams import EventSource
 from synapse.types import JsonDict, StreamKeyType, UserID, get_domain_from_id
 from synapse.util.async_helpers import Linearizer
-from synapse.util.caches.descriptors import _CacheContext, cached
 from synapse.util.metrics import Measure
 from synapse.util.wheel_timer import WheelTimer
 
@@ -1656,15 +1653,18 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 # doesn't return. C.f. #5503.
                 return [], max_token
 
-            # Figure out which other users this user should receive updates for
-            users_interested_in = await self._get_interested_in(user, explicit_room_id)
+            # Figure out which other users this user should explicitly receive
+            # updates for
+            additional_users_interested_in = (
+                await self.get_presence_router().get_interested_users(user.to_string())
+            )
 
             # We have a set of users that we're interested in the presence of. We want to
             # cross-reference that with the users that have actually changed their presence.
 
             # Check whether this user should see all user updates
 
-            if users_interested_in == PresenceRouter.ALL_USERS:
+            if additional_users_interested_in == PresenceRouter.ALL_USERS:
                 # Provide presence state for all users
                 presence_updates = await self._filter_all_presence_updates_for_user(
                     user_id, include_offline, from_key
@@ -1673,34 +1673,47 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
                 return presence_updates, max_token
 
             # Make mypy happy. users_interested_in should now be a set
-            assert not isinstance(users_interested_in, str)
+            assert not isinstance(additional_users_interested_in, str)
+
+            # We always care about our own presence.
+            additional_users_interested_in.add(user_id)
+
+            if explicit_room_id:
+                user_ids = await self.store.get_users_in_room(explicit_room_id)
+                additional_users_interested_in.update(user_ids)
 
             # The set of users that we're interested in and that have had a presence update.
             # We'll actually pull the presence updates for these users at the end.
-            interested_and_updated_users: Union[Set[str], FrozenSet[str]] = set()
+            interested_and_updated_users: Collection[str]
 
             if from_key is not None:
                 # First get all users that have had a presence update
                 updated_users = stream_change_cache.get_all_entities_changed(from_key)
 
                 # Cross-reference users we're interested in with those that have had updates.
-                # Use a slightly-optimised method for processing smaller sets of updates.
-                if updated_users is not None and len(updated_users) < 500:
-                    # For small deltas, it's quicker to get all changes and then
-                    # cross-reference with the users we're interested in
+                if updated_users is not None:
+                    # If we have the full list of changes for presence we can
+                    # simply check which ones share a room with the user.
                     get_updates_counter.labels("stream").inc()
-                    for other_user_id in updated_users:
-                        if other_user_id in users_interested_in:
-                            # mypy thinks this variable could be a FrozenSet as it's possibly set
-                            # to one in the `get_entities_changed` call below, and `add()` is not
-                            # method on a FrozenSet. That doesn't affect us here though, as
-                            # `interested_and_updated_users` is clearly a set() above.
-                            interested_and_updated_users.add(other_user_id)  # type: ignore
+
+                    sharing_users = await self.store.do_users_share_a_room(
+                        user_id, updated_users
+                    )
+
+                    interested_and_updated_users = (
+                        sharing_users.union(additional_users_interested_in)
+                    ).intersection(updated_users)
+
                 else:
                     # Too many possible updates. Find all users we can see and check
                     # if any of them have changed.
                     get_updates_counter.labels("full").inc()
 
+                    users_interested_in = (
+                        await self.store.get_users_who_share_room_with_user(user_id)
+                    )
+                    users_interested_in.update(additional_users_interested_in)
+
                     interested_and_updated_users = (
                         stream_change_cache.get_entities_changed(
                             users_interested_in, from_key
@@ -1709,7 +1722,10 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
             else:
                 # No from_key has been specified. Return the presence for all users
                 # this user is interested in
-                interested_and_updated_users = users_interested_in
+                interested_and_updated_users = (
+                    await self.store.get_users_who_share_room_with_user(user_id)
+                )
+                interested_and_updated_users.update(additional_users_interested_in)
 
             # Retrieve the current presence state for each user
             users_to_state = await self.get_presence_handler().current_state_for_users(
@@ -1804,62 +1820,6 @@ class PresenceEventSource(EventSource[int, UserPresenceState]):
     def get_current_key(self) -> int:
         return self.store.get_current_presence_token()
 
-    @cached(num_args=2, cache_context=True)
-    async def _get_interested_in(
-        self,
-        user: UserID,
-        explicit_room_id: Optional[str] = None,
-        cache_context: Optional[_CacheContext] = None,
-    ) -> Union[Set[str], str]:
-        """Returns the set of users that the given user should see presence
-        updates for.
-
-        Args:
-            user: The user to retrieve presence updates for.
-            explicit_room_id: The users that are in the room will be returned.
-
-        Returns:
-            A set of user IDs to return presence updates for, or "ALL" to return all
-            known updates.
-        """
-        user_id = user.to_string()
-        users_interested_in = set()
-        users_interested_in.add(user_id)  # So that we receive our own presence
-
-        # cache_context isn't likely to ever be None due to the @cached decorator,
-        # but we can't have a non-optional argument after the optional argument
-        # explicit_room_id either. Assert cache_context is not None so we can use it
-        # without mypy complaining.
-        assert cache_context
-
-        # Check with the presence router whether we should poll additional users for
-        # their presence information
-        additional_users = await self.get_presence_router().get_interested_users(
-            user.to_string()
-        )
-        if additional_users == PresenceRouter.ALL_USERS:
-            # If the module requested that this user see the presence updates of *all*
-            # users, then simply return that instead of calculating what rooms this
-            # user shares
-            return PresenceRouter.ALL_USERS
-
-        # Add the additional users from the router
-        users_interested_in.update(additional_users)
-
-        # Find the users who share a room with this user
-        users_who_share_room = await self.store.get_users_who_share_room_with_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
-        users_interested_in.update(users_who_share_room)
-
-        if explicit_room_id:
-            user_ids = await self.store.get_users_in_room(
-                explicit_room_id, on_invalidate=cache_context.invalidate
-            )
-            users_interested_in.update(user_ids)
-
-        return users_interested_in
-
 
 def handle_timeouts(
     user_states: List[UserPresenceState],
diff --git a/synapse/storage/_base.py b/synapse/storage/_base.py
index a2f8310388..e30f9c76d4 100644
--- a/synapse/storage/_base.py
+++ b/synapse/storage/_base.py
@@ -80,6 +80,10 @@ class SQLBaseStore(metaclass=ABCMeta):
             )
             self._attempt_to_invalidate_cache("get_local_users_in_room", (room_id,))
 
+            # There's no easy way of invalidating this cache for just the users
+            # that have changed, so we just clear the entire thing.
+            self._attempt_to_invalidate_cache("does_pair_of_users_share_a_room", None)
+
         for user_id in members_changed:
             self._attempt_to_invalidate_cache(
                 "get_user_in_room_with_profile", (room_id, user_id)
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index df6b82660e..e2cccc688c 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -21,6 +21,7 @@ from typing import (
     FrozenSet,
     Iterable,
     List,
+    Mapping,
     Optional,
     Set,
     Tuple,
@@ -55,6 +56,7 @@ from synapse.types import JsonDict, PersistedEventPosition, StateMap, get_domain
 from synapse.util.async_helpers import Linearizer
 from synapse.util.caches import intern_string
 from synapse.util.caches.descriptors import _CacheContext, cached, cachedList
+from synapse.util.iterutils import batch_iter
 from synapse.util.metrics import Measure
 
 if TYPE_CHECKING:
@@ -183,7 +185,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
                 self._check_safe_current_state_events_membership_updated_txn,
             )
 
-    @cached(max_entries=100000, iterable=True, prune_unread_entries=False)
+    @cached(max_entries=100000, iterable=True)
     async def get_users_in_room(self, room_id: str) -> List[str]:
         return await self.db_pool.runInteraction(
             "get_users_in_room", self.get_users_in_room_txn, room_id
@@ -561,7 +563,7 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return results_dict.get("membership"), results_dict.get("event_id")
 
-    @cached(max_entries=500000, iterable=True, prune_unread_entries=False)
+    @cached(max_entries=500000, iterable=True)
     async def get_rooms_for_user_with_stream_ordering(
         self, user_id: str
     ) -> FrozenSet[GetRoomsForUserWithStreamOrdering]:
@@ -732,25 +734,76 @@ class RoomMemberWorkerStore(EventsWorkerStore):
         )
         return frozenset(r.room_id for r in rooms)
 
-    @cached(
-        max_entries=500000,
-        cache_context=True,
-        iterable=True,
-        prune_unread_entries=False,
+    @cached(max_entries=10000)
+    async def does_pair_of_users_share_a_room(
+        self, user_id: str, other_user_id: str
+    ) -> bool:
+        raise NotImplementedError()
+
+    @cachedList(
+        cached_method_name="does_pair_of_users_share_a_room", list_name="other_user_ids"
     )
-    async def get_users_who_share_room_with_user(
-        self, user_id: str, cache_context: _CacheContext
+    async def _do_users_share_a_room(
+        self, user_id: str, other_user_ids: Collection[str]
+    ) -> Mapping[str, Optional[bool]]:
+        """Return mapping from user ID to whether they share a room with the
+        given user.
+
+        Note: `None` and `False` are equivalent and mean they don't share a
+        room.
+        """
+
+        def do_users_share_a_room_txn(
+            txn: LoggingTransaction, user_ids: Collection[str]
+        ) -> Dict[str, bool]:
+            clause, args = make_in_list_sql_clause(
+                self.database_engine, "state_key", user_ids
+            )
+
+            # This query works by fetching both the list of rooms for the target
+            # user and the set of other users, and then checking if there is any
+            # overlap.
+            sql = f"""
+                SELECT b.state_key
+                FROM (
+                    SELECT room_id FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join' AND state_key = ?
+                ) AS a
+                INNER JOIN (
+                    SELECT room_id, state_key FROM current_state_events
+                    WHERE type = 'm.room.member' AND membership = 'join' AND {clause}
+                ) AS b using (room_id)
+                LIMIT 1
+            """
+
+            txn.execute(sql, (user_id, *args))
+            return {u: True for u, in txn}
+
+        to_return = {}
+        for batch_user_ids in batch_iter(other_user_ids, 1000):
+            res = await self.db_pool.runInteraction(
+                "do_users_share_a_room", do_users_share_a_room_txn, batch_user_ids
+            )
+            to_return.update(res)
+
+        return to_return
+
+    async def do_users_share_a_room(
+        self, user_id: str, other_user_ids: Collection[str]
     ) -> Set[str]:
+        """Return the set of users who share a room with the first users"""
+
+        user_dict = await self._do_users_share_a_room(user_id, other_user_ids)
+
+        return {u for u, share_room in user_dict.items() if share_room}
+
+    async def get_users_who_share_room_with_user(self, user_id: str) -> Set[str]:
         """Returns the set of users who share a room with `user_id`"""
-        room_ids = await self.get_rooms_for_user(
-            user_id, on_invalidate=cache_context.invalidate
-        )
+        room_ids = await self.get_rooms_for_user(user_id)
 
         user_who_share_room = set()
         for room_id in room_ids:
-            user_ids = await self.get_users_in_room(
-                room_id, on_invalidate=cache_context.invalidate
-            )
+            user_ids = await self.get_users_in_room(room_id)
             user_who_share_room.update(user_ids)
 
         return user_who_share_room