diff options
Diffstat (limited to 'synapse/storage')
-rw-r--r-- | synapse/storage/databases/main/roommember.py | 24 | ||||
-rw-r--r-- | synapse/storage/databases/main/user_directory.py | 43 |
2 files changed, 24 insertions, 43 deletions
diff --git a/synapse/storage/databases/main/roommember.py b/synapse/storage/databases/main/roommember.py index cc528fcf2d..e222b7bd1f 100644 --- a/synapse/storage/databases/main/roommember.py +++ b/synapse/storage/databases/main/roommember.py @@ -670,6 +670,30 @@ class RoomMemberWorkerStore(EventsWorkerStore): return user_who_share_room + @cached(cache_context=True, iterable=True) + async def get_mutual_rooms_between_users( + self, user_ids: FrozenSet[str], cache_context: _CacheContext + ) -> FrozenSet[str]: + """ + Returns the set of rooms that all users in `user_ids` share. + + Args: + user_ids: A frozen set of all users to investigate and return + overlapping joined rooms for. + cache_context + """ + shared_room_ids: Optional[FrozenSet[str]] = None + for user_id in user_ids: + room_ids = await self.get_rooms_for_user( + user_id, on_invalidate=cache_context.invalidate + ) + if shared_room_ids is not None: + shared_room_ids &= room_ids + else: + shared_room_ids = room_ids + + return shared_room_ids or frozenset() + async def get_joined_users_from_context( self, event: EventBase, context: EventContext ) -> Dict[str, ProfileInfo]: diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py index 028db69af3..2282242e9d 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py @@ -729,49 +729,6 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) - async def get_mutual_rooms_for_users( - self, user_id: str, other_user_id: str - ) -> Set[str]: - """ - Returns the rooms that a local user shares with another local or remote user. - - Args: - user_id: The MXID of a local user - other_user_id: The MXID of the other user - - Returns: - A set of room ID's that the users share. - """ - - def _get_mutual_rooms_for_users_txn( - txn: LoggingTransaction, - ) -> List[Dict[str, str]]: - txn.execute( - """ - SELECT p1.room_id - FROM users_in_public_rooms as p1 - INNER JOIN users_in_public_rooms as p2 - ON p1.room_id = p2.room_id - AND p1.user_id = ? - AND p2.user_id = ? - UNION - SELECT room_id - FROM users_who_share_private_rooms - WHERE - user_id = ? - AND other_user_id = ? - """, - (user_id, other_user_id, user_id, other_user_id), - ) - rows = self.db_pool.cursor_to_dict(txn) - return rows - - rows = await self.db_pool.runInteraction( - "get_mutual_rooms_for_users", _get_mutual_rooms_for_users_txn - ) - - return {row["room_id"] for row in rows} - async def get_user_directory_stream_pos(self) -> Optional[int]: """ Get the stream ID of the user directory stream. |