diff options
Diffstat (limited to 'synapse/storage/user_directory.py')
-rw-r--r-- | synapse/storage/user_directory.py | 296 |
1 files changed, 266 insertions, 30 deletions
diff --git a/synapse/storage/user_directory.py b/synapse/storage/user_directory.py index 137aca2881..52b184fe78 100644 --- a/synapse/storage/user_directory.py +++ b/synapse/storage/user_directory.py @@ -16,16 +16,19 @@ from twisted.internet import defer from ._base import SQLBaseStore + from synapse.util.caches.descriptors import cached, cachedInlineCallbacks from synapse.api.constants import EventTypes, JoinRules from synapse.storage.engines import PostgresEngine, Sqlite3Engine from synapse.types import get_domain_from_id, get_localpart_from_id import re +import logging +logger = logging.getLogger(__name__) -class UserDirectoryStore(SQLBaseStore): +class UserDirectoryStore(SQLBaseStore): @cachedInlineCallbacks(cache_context=True) def is_room_world_readable_or_publicly_joinable(self, room_id, cache_context): """Check if the room is either world_readable or publically joinable @@ -270,27 +273,240 @@ class UserDirectoryStore(SQLBaseStore): desc="get_users_in_public_due_to_room", ) + @defer.inlineCallbacks def get_users_in_dir_due_to_room(self, room_id): """Get all user_ids that are in the room directory becuase they're in the given room_id """ - return self._simple_select_onecol( + user_ids_dir = yield self._simple_select_onecol( table="user_directory", keyvalues={"room_id": room_id}, retcol="user_id", desc="get_users_in_dir_due_to_room", ) + user_ids_pub = yield self._simple_select_onecol( + table="users_in_pubic_room", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids_share = yield self._simple_select_onecol( + table="users_who_share_rooms", + keyvalues={"room_id": room_id}, + retcol="user_id", + desc="get_users_in_dir_due_to_room", + ) + + user_ids = set(user_ids_dir) + user_ids.update(user_ids_pub) + user_ids.update(user_ids_share) + + defer.returnValue(user_ids) + + @defer.inlineCallbacks def get_all_rooms(self): - """Get all room_ids we've ever known about + """Get all room_ids we've ever known about, in ascending order of "size" """ - return self._simple_select_onecol( - table="current_state_events", - keyvalues={}, - retcol="DISTINCT room_id", - desc="get_all_rooms", + sql = """ + SELECT room_id FROM current_state_events + GROUP BY room_id + ORDER BY count(*) ASC + """ + rows = yield self._execute("get_all_rooms", None, sql) + defer.returnValue([room_id for room_id, in rows]) + + def add_users_who_share_room(self, room_id, share_private, user_id_tuples): + """Insert entries into the users_who_share_rooms table. The first + user should be a local user. + + Args: + room_id (str) + share_private (bool): Is the room private + user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + """ + def _add_users_who_share_room_txn(txn): + self._simple_insert_many_txn( + txn, + table="users_who_share_rooms", + values=[ + { + "user_id": user_id, + "other_user_id": other_user_id, + "room_id": room_id, + "share_private": share_private, + } + for user_id, other_user_id in user_id_tuples + ], + ) + for user_id, other_user_id in user_id_tuples: + txn.call_after( + self.get_users_who_share_room_from_dir.invalidate, + (user_id,), + ) + txn.call_after( + self.get_if_users_share_a_room.invalidate, + (user_id, other_user_id), + ) + return self.runInteraction( + "add_users_who_share_room", _add_users_who_share_room_txn + ) + + def update_users_who_share_room(self, room_id, share_private, user_id_sets): + """Updates entries in the users_who_share_rooms table. The first + user should be a local user. + + Args: + room_id (str) + share_private (bool): Is the room private + user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + """ + def _update_users_who_share_room_txn(txn): + sql = """ + UPDATE users_who_share_rooms + SET room_id = ?, share_private = ? + WHERE user_id = ? AND other_user_id = ? + """ + txn.executemany( + sql, + ( + (room_id, share_private, uid, oid) + for uid, oid in user_id_sets + ) + ) + for user_id, other_user_id in user_id_sets: + txn.call_after( + self.get_users_who_share_room_from_dir.invalidate, + (user_id,), + ) + txn.call_after( + self.get_if_users_share_a_room.invalidate, + (user_id, other_user_id), + ) + return self.runInteraction( + "update_users_who_share_room", _update_users_who_share_room_txn + ) + + def remove_user_who_share_room(self, user_id, other_user_id): + """Deletes entries in the users_who_share_rooms table. The first + user should be a local user. + + Args: + room_id (str) + share_private (bool): Is the room private + user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + """ + def _remove_user_who_share_room_txn(txn): + self._simple_delete_txn( + txn, + table="users_who_share_rooms", + keyvalues={ + "user_id": user_id, + "other_user_id": other_user_id, + }, + ) + txn.call_after( + self.get_users_who_share_room_from_dir.invalidate, + (user_id,), + ) + txn.call_after( + self.get_if_users_share_a_room.invalidate, + (user_id, other_user_id), + ) + + return self.runInteraction( + "remove_user_who_share_room", _remove_user_who_share_room_txn + ) + + @cached(max_entries=500000) + def get_if_users_share_a_room(self, user_id, other_user_id): + """Gets if users share a room. + + Args: + user_id (str): Must be a local user_id + other_user_id (str) + + Returns: + bool|None: None if they don't share a room, otherwise whether they + share a private room or not. + """ + return self._simple_select_one_onecol( + table="users_who_share_rooms", + keyvalues={ + "user_id": user_id, + "other_user_id": other_user_id, + }, + retcol="share_private", + allow_none=True, ) + @cachedInlineCallbacks(max_entries=500000, iterable=True) + def get_users_who_share_room_from_dir(self, user_id): + """Returns the set of users who share a room with `user_id` + + Args: + user_id(str): Must be a local user + + Returns: + dict: user_id -> share_private mapping + """ + rows = yield self._simple_select_list( + table="users_who_share_rooms", + keyvalues={ + "user_id": user_id, + }, + retcols=("other_user_id", "share_private",), + desc="get_users_who_share_room_with_user", + ) + + defer.returnValue({ + row["other_user_id"]: row["share_private"] + for row in rows + }) + + def get_users_in_share_dir_with_room_id(self, user_id, room_id): + """Get all user tuples that are in the users_who_share_rooms due to the + given room_id. + + Returns: + [(user_id, other_user_id)]: where one of the two will match the given + user_id. + """ + sql = """ + SELECT user_id, other_user_id FROM users_who_share_rooms + WHERE room_id = ? AND (user_id = ? OR other_user_id = ?) + """ + return self._execute( + "get_users_in_share_dir_with_room_id", None, sql, room_id, user_id, user_id + ) + + @defer.inlineCallbacks + def get_rooms_in_common_for_users(self, user_id, other_user_id): + """Given two user_ids find out the list of rooms they share. + """ + sql = """ + SELECT room_id FROM ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships USING (event_id) + WHERE type = 'm.room.member' + AND membership = 'join' + AND state_key = ? + ) AS f1 INNER JOIN ( + SELECT c.room_id FROM current_state_events AS c + INNER JOIN room_memberships USING (event_id) + WHERE type = 'm.room.member' + AND membership = 'join' + AND state_key = ? + ) f2 USING (room_id) + """ + + rows = yield self._execute( + "get_rooms_in_common_for_users", None, sql, user_id, other_user_id + ) + + defer.returnValue([room_id for room_id, in rows]) + def delete_all_from_user_dir(self): """Delete the entire user directory """ @@ -298,8 +514,11 @@ class UserDirectoryStore(SQLBaseStore): txn.execute("DELETE FROM user_directory") txn.execute("DELETE FROM user_directory_search") txn.execute("DELETE FROM users_in_pubic_room") + txn.execute("DELETE FROM users_who_share_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) txn.call_after(self.get_user_in_public_room.invalidate_all) + txn.call_after(self.get_users_who_share_room_from_dir.invalidate_all) + txn.call_after(self.get_if_users_share_a_room.invalidate_all) return self.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @@ -392,7 +611,7 @@ class UserDirectoryStore(SQLBaseStore): ) @defer.inlineCallbacks - def search_user_dir(self, search_term, limit): + def search_user_dir(self, user_id, search_term, limit): """Searches for users in directory Returns: @@ -418,46 +637,63 @@ class UserDirectoryStore(SQLBaseStore): # The array of numbers are the weights for the various part of the # search: (domain, _, display name, localpart) sql = """ - SELECT user_id, display_name, avatar_url + SELECT d.user_id, display_name, avatar_url FROM user_directory_search - INNER JOIN user_directory USING (user_id) - INNER JOIN users_in_pubic_room USING (user_id) - WHERE vector @@ to_tsquery('english', ?) + INNER JOIN user_directory AS d USING (user_id) + LEFT JOIN users_in_pubic_room AS p USING (user_id) + LEFT JOIN ( + SELECT other_user_id AS user_id FROM users_who_share_rooms + WHERE user_id = ? AND share_private + ) AS s USING (user_id) + WHERE + (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + AND vector @@ to_tsquery('english', ?) ORDER BY - 2 * ts_rank_cd( - '{0.1, 0.1, 0.9, 1.0}', - vector, - to_tsquery('english', ?), - 8 - ) - + ts_rank_cd( - '{0.1, 0.1, 0.9, 1.0}', - vector, - to_tsquery('english', ?), - 8 + (CASE WHEN s.user_id IS NOT NULL THEN 4.0 ELSE 1.0 END) + * (CASE WHEN display_name IS NOT NULL THEN 1.2 ELSE 1.0 END) + * (CASE WHEN avatar_url IS NOT NULL THEN 1.2 ELSE 1.0 END) + * ( + 3 * ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) + + ts_rank_cd( + '{0.1, 0.1, 0.9, 1.0}', + vector, + to_tsquery('english', ?), + 8 + ) ) DESC, display_name IS NULL, avatar_url IS NULL LIMIT ? """ - args = (full_query, exact_query, prefix_query, limit + 1,) + args = (user_id, full_query, exact_query, prefix_query, limit + 1,) elif isinstance(self.database_engine, Sqlite3Engine): search_query = _parse_query_sqlite(search_term) sql = """ - SELECT user_id, display_name, avatar_url + SELECT d.user_id, display_name, avatar_url FROM user_directory_search - INNER JOIN user_directory USING (user_id) - INNER JOIN users_in_pubic_room USING (user_id) - WHERE value MATCH ? + INNER JOIN user_directory AS d USING (user_id) + LEFT JOIN users_in_pubic_room AS p USING (user_id) + LEFT JOIN ( + SELECT other_user_id AS user_id FROM users_who_share_rooms + WHERE user_id = ? AND share_private + ) AS s USING (user_id) + WHERE + (s.user_id IS NOT NULL OR p.user_id IS NOT NULL) + AND value MATCH ? ORDER BY rank(matchinfo(user_directory_search)) DESC, display_name IS NULL, avatar_url IS NULL LIMIT ? """ - args = (search_query, limit + 1) + args = (user_id, search_query, limit + 1) else: # This should be unreachable. raise Exception("Unrecognized database engine") |