summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--synapse/rest/client/v2_alpha/shared_rooms.py2
-rw-r--r--synapse/storage/databases/main/user_directory.py36
2 files changed, 37 insertions, 1 deletions
diff --git a/synapse/rest/client/v2_alpha/shared_rooms.py b/synapse/rest/client/v2_alpha/shared_rooms.py
index 86feec2145..69ecd552ec 100644
--- a/synapse/rest/client/v2_alpha/shared_rooms.py
+++ b/synapse/rest/client/v2_alpha/shared_rooms.py
@@ -46,7 +46,7 @@ class UserSharedRoomsServlet(RestServlet):
                 errcode=Codes.FORBIDDEN,
             )
 
-        rooms = await self.store.get_rooms_in_common_for_users(
+        rooms = await self.store.get_shared_rooms_for_users(
             requester.user.to_string(), user_id
         )
 
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py
index af21fe457a..1c73e51c21 100644
--- a/synapse/storage/databases/main/user_directory.py
+++ b/synapse/storage/databases/main/user_directory.py
@@ -663,6 +663,42 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore):
         users.update(rows)
         return list(users)
 
+    @cached()
+    async def get_shared_rooms_for_users(self, user_id, other_user_id):
+        """
+        Returns the rooms that a user is in.
+
+        Args:
+            user_id(str): Must be a local user
+
+        Returns:
+            list: user_id
+        """
+        SQL = """
+            SELECT p1.room_id 
+            FROM users_in_public_rooms as p1
+            INNER JOIN users_in_public_rooms as p2
+                ON p1.room_id = p2.room_id
+                AND p1.user_id = ?
+                AND p2.user_id = ?
+            UNION
+            SELECT room_id
+            FROM users_who_share_private_rooms
+            WHERE
+                user_id = ?
+                AND other_user_id = ?;
+        """
+        rows = await self.db_pool.execute(
+            "get_shared_rooms_for_users",
+            None,
+            SQL,
+            user_id,
+            other_user_id,
+            user_id,
+            other_user_id
+        )
+        return list({row[0] for row in rows})
+
     def get_user_directory_stream_pos(self):
         return self.db_pool.simple_select_one_onecol(
             table="user_directory_stream_pos",