summary refs log tree commit diff
path: root/synapse
diff options
context:
space:
mode:
Diffstat (limited to 'synapse')
-rw-r--r--synapse/rest/client/mutual_rooms.py15
-rw-r--r--synapse/storage/databases/main/roommember.py24
-rw-r--r--synapse/storage/databases/main/user_directory.py43
3 files changed, 26 insertions, 56 deletions
diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py
index 27bfaf0b29..38ef4e459f 100644
--- a/synapse/rest/client/mutual_rooms.py
+++ b/synapse/rest/client/mutual_rooms.py
@@ -42,21 +42,10 @@ class UserMutualRoomsServlet(RestServlet):
         super().__init__()
         self.auth = hs.get_auth()
         self.store = hs.get_datastores().main
-        self.user_directory_search_enabled = (
-            hs.config.userdirectory.user_directory_search_enabled
-        )
 
     async def on_GET(
         self, request: SynapseRequest, user_id: str
     ) -> Tuple[int, JsonDict]:
-
-        if not self.user_directory_search_enabled:
-            raise SynapseError(
-                code=400,
-                msg="User directory searching is disabled. Cannot determine shared rooms.",
-                errcode=Codes.UNKNOWN,
-            )
-
         UserID.from_string(user_id)
 
         requester = await self.auth.get_user_by_req(request)
@@ -67,8 +56,8 @@ class UserMutualRoomsServlet(RestServlet):
                 errcode=Codes.FORBIDDEN,
             )
 
-        rooms = await self.store.get_mutual_rooms_for_users(
-            requester.user.to_string(), user_id
+        rooms = await self.store.get_mutual_rooms_between_users(
+            frozenset((requester.user.to_string(), user_id))
         )
 
         return 200, {"joined": list(rooms)}
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index cc528fcf2d..e222b7bd1f 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -670,6 +670,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
 
         return user_who_share_room
 
+    @cached(cache_context=True, iterable=True)
+    async def get_mutual_rooms_between_users(
+        self, user_ids: FrozenSet[str], cache_context: _CacheContext
+    ) -> FrozenSet[str]:
+        """
+        Returns the set of rooms that all users in `user_ids` share.
+
+        Args:
+            user_ids: A frozen set of all users to investigate and return
+              overlapping joined rooms for.
+            cache_context
+        """
+        shared_room_ids: Optional[FrozenSet[str]] = None
+        for user_id in user_ids:
+            room_ids = await self.get_rooms_for_user(
+                user_id, on_invalidate=cache_context.invalidate
+            )
+            if shared_room_ids is not None:
+                shared_room_ids &= room_ids
+            else:
+                shared_room_ids = room_ids
+
+        return shared_room_ids or frozenset()
+
     async def get_joined_users_from_context(
         self, event: EventBase, context: EventContext
     ) -> Dict[str, ProfileInfo]:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 028db69af3..2282242e9d 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -729,49 +729,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
-    async def get_mutual_rooms_for_users(
-        self, user_id: str, other_user_id: str
-    ) -> Set[str]:
-        """
-        Returns the rooms that a local user shares with another local or remote user.
-
-        Args:
-            user_id: The MXID of a local user
-            other_user_id: The MXID of the other user
-
-        Returns:
-            A set of room ID's that the users share.
-        """
-
-        def _get_mutual_rooms_for_users_txn(
-            txn: LoggingTransaction,
-        ) -> List[Dict[str, str]]:
-            txn.execute(
-                """
-                SELECT p1.room_id
-                FROM users_in_public_rooms as p1
-                INNER JOIN users_in_public_rooms as p2
-                    ON p1.room_id = p2.room_id
-                    AND p1.user_id = ?
-                    AND p2.user_id = ?
-                UNION
-                SELECT room_id
-                FROM users_who_share_private_rooms
-                WHERE
-                    user_id = ?
-                    AND other_user_id = ?
-                """,
-                (user_id, other_user_id, user_id, other_user_id),
-            )
-            rows = self.db_pool.cursor_to_dict(txn)
-            return rows
-
-        rows = await self.db_pool.runInteraction(
-            "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
-        )
-
-        return {row["room_id"] for row in rows}
-
     async def get_user_directory_stream_pos(self) -> Optional[int]:
         """
         Get the stream ID of the user directory stream.