summary refs log tree commit diff
diff options
context:
space:
mode:
authorNeil Johnson <neil@matrix.org>2024-04-11 22:19:04 +0100
committerNeil Johnson <neil@matrix.org>2024-04-11 22:19:04 +0100
commit74d133418220e34f3f06bb5a259772d0f7cd9737 (patch)
tree2ebd992e785012e260dead7354e507edd6a8f892
parentAlso check if first event matches the last in prev batch (#17066) (diff)
downloadsynapse-74d133418220e34f3f06bb5a259772d0f7cd9737.tar.xz
move _update_join_states to be processed by a ScheduledTask
-rw-r--r--synapse/handlers/profile.py55
1 files changed, 47 insertions, 8 deletions
diff --git a/synapse/handlers/profile.py b/synapse/handlers/profile.py
index e51e282a9f..b8f4ab47e4 100644
--- a/synapse/handlers/profile.py
+++ b/synapse/handlers/profile.py
@@ -20,7 +20,7 @@
 #
 import logging
 import random
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, Optional, Tuple, Union
 
 from synapse.api.errors import (
     AuthError,
@@ -31,7 +31,15 @@ from synapse.api.errors import (
     SynapseError,
 )
 from synapse.storage.databases.main.media_repository import LocalMedia, RemoteMedia
-from synapse.types import JsonDict, Requester, UserID, create_requester
+from synapse.types import (
+    JsonDict,
+    JsonMapping,
+    Requester,
+    ScheduledTask,
+    TaskStatus,
+    UserID,
+    create_requester,
+)
 from synapse.util.caches.descriptors import cached
 from synapse.util.stringutils import parse_and_validate_mxc_uri
 
@@ -43,6 +51,8 @@ logger = logging.getLogger(__name__)
 MAX_DISPLAYNAME_LEN = 256
 MAX_AVATAR_URL_LEN = 1000
 
+UPDATE_JOIN_STATES_TASK_NAME = "update_join_states"
+
 
 class ProfileHandler:
     """Handles fetching and updating user profile information.
@@ -71,6 +81,11 @@ class ProfileHandler:
 
         self._third_party_rules = hs.get_module_api_callbacks().third_party_event_rules
 
+        self._task_scheduler = hs.get_task_scheduler()
+        self._task_scheduler.register_action(
+            self._update_join_states, UPDATE_JOIN_STATES_TASK_NAME
+        )
+
     async def get_profile(self, user_id: str, ignore_backoff: bool = True) -> JsonDict:
         target_user = UserID.from_string(user_id)
 
@@ -198,7 +213,13 @@ class ProfileHandler:
         )
 
         if propagate:
-            await self._update_join_states(requester, target_user)
+            await self._task_scheduler.schedule_task(
+                UPDATE_JOIN_STATES_TASK_NAME,
+                params={
+                    "requester": requester.serialize(),
+                    "target_user": target_user.to_string(),
+                },
+            )
 
     async def get_avatar_url(self, target_user: UserID) -> Optional[str]:
         if self.hs.is_mine(target_user):
@@ -291,7 +312,13 @@ class ProfileHandler:
         )
 
         if propagate:
-            await self._update_join_states(requester, target_user)
+            await self._task_scheduler.schedule_task(
+                UPDATE_JOIN_STATES_TASK_NAME,
+                params={
+                    "requester": requester.serialize(),
+                    "target_user": target_user.to_string(),
+                },
+            )
 
     @cached()
     async def check_avatar_size_and_mime_type(self, mxc: str) -> bool:
@@ -393,10 +420,21 @@ class ProfileHandler:
         return response
 
     async def _update_join_states(
-        self, requester: Requester, target_user: UserID
-    ) -> None:
+        self, task: ScheduledTask
+    ) -> Tuple[TaskStatus, Optional[JsonMapping], Optional[str]]:
+        """Updates join states following a change to display name or avatar
+
+        Args:
+            target_user: The owner of the queried profile. This is a str rather
+            than a UserID because the task_scheduler requires JSON serializable
+            parameters
+            requester: The user querying for the profile.
+        """
+        assert task.params is not None
+        requester = Requester.deserialize(self.store, task.params["requester"])
+        target_user = UserID.from_string(task.params["target_user"])
         if not self.hs.is_mine(target_user):
-            return
+            return TaskStatus.COMPLETE, None, None
 
         await self.request_ratelimiter.ratelimit(requester)
 
@@ -404,7 +442,7 @@ class ProfileHandler:
         if requester.shadow_banned:
             # We randomly sleep a bit just to annoy the requester.
             await self.clock.sleep(random.randint(1, 10))
-            return
+            return TaskStatus.COMPLETE, None, None
 
         room_ids = await self.store.get_rooms_for_user(target_user.to_string())
 
@@ -424,6 +462,7 @@ class ProfileHandler:
                 logger.warning(
                     "Failed to update join event for room %s - %s", room_id, str(e)
                 )
+        return TaskStatus.COMPLETE, None, None
 
     async def check_profile_query_allowed(
         self, target_user: UserID, requester: Optional[UserID] = None