summary refs log tree commit diff
path: root/synapse/storage/databases/main/user_directory.py
diff options
context:
space:
mode:
Diffstat (limited to 'synapse/storage/databases/main/user_directory.py')
-rw-r--r--synapse/storage/databases/main/user_directory.py91
1 files changed, 72 insertions, 19 deletions
diff --git a/synapse/storage/databases/main/user_directory.py b/synapse/storage/databases/main/user_directory.py

index a9f2e93614..f2f9a5799a 100644 --- a/synapse/storage/databases/main/user_directory.py +++ b/synapse/storage/databases/main/user_directory.py
@@ -15,7 +15,7 @@ import logging import re -from typing import Any, Dict, Optional +from typing import Any, Dict, Iterable, Optional, Set, Tuple from synapse.api.constants import EventTypes, JoinRules from synapse.storage.database import DatabasePool @@ -365,10 +365,17 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): return False - def update_profile_in_user_dir(self, user_id, display_name, avatar_url): + async def update_profile_in_user_dir( + self, user_id: str, display_name: str, avatar_url: str + ) -> None: """ Update or add a user's profile in the user directory. """ + # If the display name or avatar URL are unexpected types, overwrite them. + if not isinstance(display_name, str): + display_name = None + if not isinstance(avatar_url, str): + avatar_url = None def _update_profile_in_user_dir_txn(txn): new_entry = self.db_pool.simple_upsert_txn( @@ -458,17 +465,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "update_profile_in_user_dir", _update_profile_in_user_dir_txn ) - def add_users_who_share_private_room(self, room_id, user_id_tuples): + async def add_users_who_share_private_room( + self, room_id: str, user_id_tuples: Iterable[Tuple[str, str]] + ) -> None: """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. Args: - room_id (str) - user_id_tuples([(str, str)]): iterable of 2-tuple of user IDs. + room_id + user_id_tuples: iterable of 2-tuple of user IDs. """ def _add_users_who_share_room_txn(txn): @@ -484,17 +493,19 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): value_values=None, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_users_who_share_room", _add_users_who_share_room_txn ) - def add_users_in_public_rooms(self, room_id, user_ids): + async def add_users_in_public_rooms( + self, room_id: str, user_ids: Iterable[str] + ) -> None: """Insert entries into the users_who_share_private_rooms table. The first user should be a local user. Args: - room_id (str) - user_ids (list[str]) + room_id + user_ids """ def _add_users_in_public_rooms_txn(txn): @@ -508,11 +519,11 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): value_values=None, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "add_users_in_public_rooms", _add_users_in_public_rooms_txn ) - def delete_all_from_user_dir(self): + async def delete_all_from_user_dir(self) -> None: """Delete the entire user directory """ @@ -523,7 +534,7 @@ class UserDirectoryBackgroundUpdateStore(StateDeltasStore): txn.execute("DELETE FROM users_who_share_private_rooms") txn.call_after(self.get_user_in_directory.invalidate_all) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "delete_all_from_user_dir", _delete_all_from_user_dir_txn ) @@ -555,7 +566,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): def __init__(self, database: DatabasePool, db_conn, hs): super(UserDirectoryStore, self).__init__(database, db_conn, hs) - def remove_from_user_dir(self, user_id): + async def remove_from_user_dir(self, user_id: str) -> None: def _remove_from_user_dir_txn(txn): self.db_pool.simple_delete_txn( txn, table="user_directory", keyvalues={"user_id": user_id} @@ -578,7 +589,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): ) txn.call_after(self.get_user_in_directory.invalidate, (user_id,)) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_from_user_dir", _remove_from_user_dir_txn ) @@ -605,14 +616,14 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): return user_ids - def remove_user_who_share_room(self, user_id, room_id): + async def remove_user_who_share_room(self, user_id: str, room_id: str) -> None: """ Deletes entries in the users_who_share_*_rooms table. The first user should be a local user. Args: - user_id (str) - room_id (str) + user_id + room_id """ def _remove_user_who_share_room_txn(txn): @@ -632,7 +643,7 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): keyvalues={"user_id": user_id, "room_id": room_id}, ) - return self.db_pool.runInteraction( + await self.db_pool.runInteraction( "remove_user_who_share_room", _remove_user_who_share_room_txn ) @@ -664,6 +675,48 @@ class UserDirectoryStore(UserDirectoryBackgroundUpdateStore): users.update(rows) return list(users) + @cached() + async def get_shared_rooms_for_users( + self, user_id: str, other_user_id: str + ) -> Set[str]: + """ + Returns the rooms that a local user shares with another local or remote user. + + Args: + user_id: The MXID of a local user + other_user_id: The MXID of the other user + + Returns: + A set of room ID's that the users share. + """ + + def _get_shared_rooms_for_users_txn(txn): + txn.execute( + """ + SELECT p1.room_id + FROM users_in_public_rooms as p1 + INNER JOIN users_in_public_rooms as p2 + ON p1.room_id = p2.room_id + AND p1.user_id = ? + AND p2.user_id = ? + UNION + SELECT room_id + FROM users_who_share_private_rooms + WHERE + user_id = ? + AND other_user_id = ? + """, + (user_id, other_user_id, user_id, other_user_id), + ) + rows = self.db_pool.cursor_to_dict(txn) + return rows + + rows = await self.db_pool.runInteraction( + "get_shared_rooms_for_users", _get_shared_rooms_for_users_txn + ) + + return {row["room_id"] for row in rows} + async def get_user_directory_stream_pos(self) -> int: return await self.db_pool.simple_select_one_onecol( table="user_directory_stream_pos",