diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py
index 0123e28f99..2a17cbc9e9 100644
--- a/synapse/storage/user_directory.py
+++ b/synapse/storage/user_directory.py
@@ -273,17 +273,38 @@ class UserDirectoryStore(SQLBaseStore):
desc="get_users_in_public_due_to_room",
)
+ @defer.inlineCallbacks
def get_users_in_dir_due_to_room(self, room_id):
"""Get all user_ids that are in the room directory becuase they're
in the given room_id
"""
- return self._simple_select_onecol(
+ user_ids_dir = yield self._simple_select_onecol(
table="user_directory",
keyvalues={"room_id": room_id},
retcol="user_id",
desc="get_users_in_dir_due_to_room",
)
+ user_ids_pub = yield self._simple_select_onecol(
+ table="users_in_pubic_room",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids_share = yield self._simple_select_onecol(
+ table="users_who_share_rooms",
+ keyvalues={"room_id": room_id},
+ retcol="user_id",
+ desc="get_users_in_dir_due_to_room",
+ )
+
+ user_ids = set(user_ids_dir)
+ user_ids.update(user_ids_pub)
+ user_ids.update(user_ids_share)
+
+ defer.returnValue(user_ids)
+
@defer.inlineCallbacks
def get_all_rooms(self):
"""Get all room_ids we've ever known about, in ascending order of "size"
@@ -398,6 +419,94 @@ class UserDirectoryStore(SQLBaseStore):
"remove_user_who_share_room", _remove_user_who_share_room_txn
)
+ @cached(max_entries=500000)
+ def get_if_users_share_a_room(self, user_id, other_user_id):
+ """Gets if users share a room.
+
+ Args:
+ user_id (str): Must be a local user_id
+ other_user_id (str)
+
+ Returns:
+ bool|None: None if they don't share a room, otherwise whether they
+ share a private room or not.
+ """
+ return self._simple_select_one_onecol(
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ "other_user_id": other_user_id,
+ },
+ retcol="share_private",
+ allow_none=True,
+ )
+
+ @cachedInlineCallbacks(max_entries=500000, iterable=True)
+ def get_users_who_share_room_from_dir(self, user_id):
+ """Returns the set of users who share a room with `user_id`
+
+ Args:
+ user_id(str): Must be a local user
+
+ Returns:
+ dict: user_id -> share_private mapping
+ """
+ rows = yield self._simple_select_list(
+ table="users_who_share_rooms",
+ keyvalues={
+ "user_id": user_id,
+ },
+ retcols=("other_user_id", "share_private",),
+ desc="get_users_who_share_room_with_user",
+ )
+
+ defer.returnValue({
+ row["other_user_id"]: row["share_private"]
+ for row in rows
+ })
+
+ def get_users_in_share_dir_with_room_id(self, user_id, room_id):
+ """Get all user tuples that are in the users_who_share_rooms due to the
+ given room_id.
+
+ Returns:
+ [(user_id, other_user_id)]: where one of the two will match the given
+ user_id.
+ """
+ sql = """
+ SELECT user_id, other_user_id FROM users_who_share_rooms
+ WHERE room_id = ? AND (user_id = ? OR other_user_id = ?)
+ """
+ return self._execute(
+ "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id
+ )
+
+ @defer.inlineCallbacks
+ def get_rooms_in_common_for_users(self, user_id, other_user_id):
+ """Given two user_ids find out the list of rooms they share.
+ """
+ sql = """
+ SELECT room_id FROM (
+ SELECT c.room_id FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key = ?
+ ) AS f1 INNER JOIN (
+ SELECT c.room_id FROM current_state_events AS c
+ INNER JOIN room_memberships USING (event_id)
+ WHERE type = 'm.room.member'
+ AND membership = 'join'
+ AND state_key = ?
+ ) f2 USING (room_id)
+ """
+
+ rows = yield self._execute(
+ "get_rooms_in_common_for_users", None, sql, user_id, other_user_id
+ )
+
+ defer.returnValue([room_id for room_id, in rows])
+
def delete_all_from_user_dir(self):
"""Delete the entire user directory
"""
|