diff --git a/changelog.d/12836.misc b/changelog.d/12836.misc
new file mode 100644
index 0000000000..85909c6a2d
--- /dev/null
+++ b/changelog.d/12836.misc
@@ -0,0 +1 @@
+Remove Mutual Rooms ([MSC2666](https://github.com/matrix-org/matrix-spec-proposals/pull/2666)) endpoint dependency on the User Directory.
\ No newline at end of file
diff --git a/synapse/rest/client/mutual_rooms.py b/synapse/rest/client/mutual_rooms.py
index 27bfaf0b29..38ef4e459f 100644
--- a/synapse/rest/client/mutual_rooms.py
+++ b/synapse/rest/client/mutual_rooms.py
@@ -42,21 +42,10 @@ class UserMutualRoomsServlet(RestServlet):
super().__init__()
self.auth = hs.get_auth()
self.store = hs.get_datastores().main
- self.user_directory_search_enabled = (
- hs.config.userdirectory.user_directory_search_enabled
- )
async def on_GET(
self, request: SynapseRequest, user_id: str
) -> Tuple[int, JsonDict]:
-
- if not self.user_directory_search_enabled:
- raise SynapseError(
- code=400,
- msg="User directory searching is disabled. Cannot determine shared rooms.",
- errcode=Codes.UNKNOWN,
- )
-
UserID.from_string(user_id)
requester = await self.auth.get_user_by_req(request)
@@ -67,8 +56,8 @@ class UserMutualRoomsServlet(RestServlet):
errcode=Codes.FORBIDDEN,
)
- rooms = await self.store.get_mutual_rooms_for_users(
- requester.user.to_string(), user_id
+ rooms = await self.store.get_mutual_rooms_between_users(
+ frozenset((requester.user.to_string(), user_id))
)
return 200, {"joined": list(rooms)}
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py
index cc528fcf2d..e222b7bd1f 100644
--- a/synapse/storage/databases/main/roommember.py
+++ b/synapse/storage/databases/main/roommember.py
@@ -670,6 +670,30 @@ class RoomMemberWorkerStore(EventsWorkerStore):
return user_who_share_room
+ @cached(cache_context=True, iterable=True)
+ async def get_mutual_rooms_between_users(
+ self, user_ids: FrozenSet[str], cache_context: _CacheContext
+ ) -> FrozenSet[str]:
+ """
+ Returns the set of rooms that all users in `user_ids` share.
+
+ Args:
+ user_ids: A frozen set of all users to investigate and return
+ overlapping joined rooms for.
+ cache_context
+ """
+ shared_room_ids: Optional[FrozenSet[str]] = None
+ for user_id in user_ids:
+ room_ids = await self.get_rooms_for_user(
+ user_id, on_invalidate=cache_context.invalidate
+ )
+ if shared_room_ids is not None:
+ shared_room_ids &= room_ids
+ else:
+ shared_room_ids = room_ids
+
+ return shared_room_ids or frozenset()
+
async def get_joined_users_from_context(
self, event: EventBase, context: EventContext
) -> Dict[str, ProfileInfo]:
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index 028db69af3..2282242e9d 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -729,49 +729,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
users.update(rows)
return list(users)
- async def get_mutual_rooms_for_users(
- self, user_id: str, other_user_id: str
- ) -> Set[str]:
- """
- Returns the rooms that a local user shares with another local or remote user.
-
- Args:
- user_id: The MXID of a local user
- other_user_id: The MXID of the other user
-
- Returns:
- A set of room ID's that the users share.
- """
-
- def _get_mutual_rooms_for_users_txn(
- txn: LoggingTransaction,
- ) -> List[Dict[str, str]]:
- txn.execute(
- """
- SELECT p1.room_id
- FROM users_in_public_rooms as p1
- INNER JOIN users_in_public_rooms as p2
- ON p1.room_id = p2.room_id
- AND p1.user_id = ?
- AND p2.user_id = ?
- UNION
- SELECT room_id
- FROM users_who_share_private_rooms
- WHERE
- user_id = ?
- AND other_user_id = ?
- """,
- (user_id, other_user_id, user_id, other_user_id),
- )
- rows = self.db_pool.cursor_to_dict(txn)
- return rows
-
- rows = await self.db_pool.runInteraction(
- "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn
- )
-
- return {row["room_id"] for row in rows}
-
async def get_user_directory_stream_pos(self) -> Optional[int]:
"""
Get the stream ID of the user directory stream.
diff --git a/tests/rest/client/test_mutual_rooms.py b/tests/rest/client/test_mutual_rooms.py
index 7b7d283bb6..a4327f7ace 100644
--- a/tests/rest/client/test_mutual_rooms.py
+++ b/tests/rest/client/test_mutual_rooms.py
@@ -36,12 +36,10 @@ class UserMutualRoomsTest(unittest.HomeserverTestCase):
def make_homeserver(self, reactor: MemoryReactor, clock: Clock) -> HomeServer:
config = self.default_config()
- config["update_user_directory"] = True
return self.setup_test_homeserver(config=config)
def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
self.store = hs.get_datastores().main
- self.handler = hs.get_user_directory_handler()
def _get_mutual_rooms(self, token: str, other_user: str) -> FakeChannel:
return self.make_request(
|