diff options
author | Neil Johnson <neil@matrix.org> | 2024-04-11 22:19:04 +0100 |
---|---|---|
committer | Neil Johnson <neil@matrix.org> | 2024-04-11 22:19:04 +0100 |
commit | 74d133418220e34f3f06bb5a259772d0f7cd9737 (patch) | |
tree | 2ebd992e785012e260dead7354e507edd6a8f892 | |
parent | Also check if first event matches the last in prev batch (#17066) (diff) | |
download | synapse-74d133418220e34f3f06bb5a259772d0f7cd9737.tar.xz |
move _update_join_states to be processed by a ScheduledTask
-rw-r--r-- | synapse/handlers/profile.py | 55 |
1 files changed, 47 insertions, 8 deletions
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py index e51e282a9f..b8f4ab47e4 100644 --- a/synapse/handlers/profile.py +++ b/synapse/handlers/profile.py @@ -20,7 +20,7 @@ # import logging import random -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional, Tuple, Union from synapse.api.errors import ( AuthError, @@ -31,7 +31,15 @@ from synapse.api.errors import ( SynapseError, ) from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia -from synapse.types import JsonDict, Requester, UserID, create_requester +from synapse.types import ( + JsonDict, + JsonMapping, + Requester, + ScheduledTask, + TaskStatus, + UserID, + create_requester, +) from synapse.util.caches.descriptors import cached from synapse.util.stringutils import parse_and_validate_mxc_uri @@ -43,6 +51,8 @@ logger = logging.getLogger(__name__) MAX_DISPLAYNAME_LEN = 256 MAX_AVATAR_URL_LEN = 1000 +UPDATE_JOIN_STATES_TASK_NAME = "update_join_states" + class ProfileHandler: """Handles fetching and updating user profile information. @@ -71,6 +81,11 @@ class ProfileHandler: self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules + self._task_scheduler = hs.get_task_scheduler() + self._task_scheduler.register_action( + self._update_join_states, UPDATE_JOIN_STATES_TASK_NAME + ) + async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict: target_user = UserID.from_string(user_id) @@ -198,7 +213,13 @@ class ProfileHandler: ) if propagate: - await self._update_join_states(requester, target_user) + await self._task_scheduler.schedule_task( + UPDATE_JOIN_STATES_TASK_NAME, + params={ + "requester": requester.serialize(), + "target_user": target_user.to_string(), + }, + ) async def get_avatar_url(self, target_user: UserID) -> Optional[str]: if self.hs.is_mine(target_user): @@ -291,7 +312,13 @@ class ProfileHandler: ) if propagate: - await self._update_join_states(requester, target_user) + await self._task_scheduler.schedule_task( + UPDATE_JOIN_STATES_TASK_NAME, + params={ + "requester": requester.serialize(), + "target_user": target_user.to_string(), + }, + ) @cached() async def check_avatar_size_and_mime_type(self, mxc: str) -> bool: @@ -393,10 +420,21 @@ class ProfileHandler: return response async def _update_join_states( - self, requester: Requester, target_user: UserID - ) -> None: + self, task: ScheduledTask + ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]: + """Updates join states following a change to display name or avatar + + Args: + target_user: The owner of the queried profile. This is a str rather + than a UserID because the task_scheduler requires JSON serializable + parameters + requester: The user querying for the profile. + """ + assert task.params is not None + requester = Requester.deserialize(self.store, task.params["requester"]) + target_user = UserID.from_string(task.params["target_user"]) if not self.hs.is_mine(target_user): - return + return TaskStatus.COMPLETE, None, None await self.request_ratelimiter.ratelimit(requester) @@ -404,7 +442,7 @@ class ProfileHandler: if requester.shadow_banned: # We randomly sleep a bit just to annoy the requester. await self.clock.sleep(random.randint(1, 10)) - return + return TaskStatus.COMPLETE, None, None room_ids = await self.store.get_rooms_for_user(target_user.to_string()) @@ -424,6 +462,7 @@ class ProfileHandler: logger.warning( "Failed to update join event for room %s - %s", room_id, str(e) ) + return TaskStatus.COMPLETE, None, None async def check_profile_query_allowed( self, target_user: UserID, requester: Optional[UserID] = None |