summary refs log tree commit diff
path: root/synapse/handlers/message.py
diff options
context:
space:
mode:
authorShay <hillerys@element.io>2022-10-12 11:01:00 -0700
committerGitHub <noreply@github.com>2022-10-12 11:01:00 -0700
commitb6baa46db078c3ef9e6c5751bccb8d2e1c5c5402 (patch)
treeea875cfe0e2023373822f81ed0ac6d681b25c3eb /synapse/handlers/message.py
parentReturn the thread ID properly down sync. (#14159) (diff)
downloadsynapse-b6baa46db078c3ef9e6c5751bccb8d2e1c5c5402.tar.xz
Fix a bug where the joined hosts for a given event were not being properly cached (#14125)
Diffstat (limited to 'synapse/handlers/message.py')
-rw-r--r--synapse/handlers/message.py91
1 files changed, 47 insertions, 44 deletions
diff --git a/synapse/handlers/message.py b/synapse/handlers/message.py
index da1acea275..4e55ebba0b 100644
--- a/synapse/handlers/message.py
+++ b/synapse/handlers/message.py
@@ -1390,7 +1390,7 @@ class EventCreationHandler:
                             extra_users=extra_users,
                         ),
                         run_in_background(
-                            self.cache_joined_hosts_for_event, event, context
+                            self.cache_joined_hosts_for_events, events_and_context
                         ).addErrback(
                             log_failure, "cache_joined_hosts_for_event failed"
                         ),
@@ -1491,62 +1491,65 @@ class EventCreationHandler:
                 await self.store.remove_push_actions_from_staging(event.event_id)
             raise
 
-    async def cache_joined_hosts_for_event(
-        self, event: EventBase, context: EventContext
+    async def cache_joined_hosts_for_events(
+        self, events_and_context: List[Tuple[EventBase, EventContext]]
     ) -> None:
-        """Precalculate the joined hosts at the event, when using Redis, so that
+        """Precalculate the joined hosts at each of the given events, when using Redis, so that
         external federation senders don't have to recalculate it themselves.
         """
 
-        if not self._external_cache.is_enabled():
-            return
-
-        # If external cache is enabled we should always have this.
-        assert self._external_cache_joined_hosts_updates is not None
+        for event, _ in events_and_context:
+            if not self._external_cache.is_enabled():
+                return
 
-        # We actually store two mappings, event ID -> prev state group,
-        # state group -> joined hosts, which is much more space efficient
-        # than event ID -> joined hosts.
-        #
-        # Note: We have to cache event ID -> prev state group, as we don't
-        # store that in the DB.
-        #
-        # Note: We set the state group -> joined hosts cache if it hasn't been
-        # set for a while, so that the expiry time is reset.
+            # If external cache is enabled we should always have this.
+            assert self._external_cache_joined_hosts_updates is not None
 
-        state_entry = await self.state.resolve_state_groups_for_events(
-            event.room_id, event_ids=event.prev_event_ids()
-        )
+            # We actually store two mappings, event ID -> prev state group,
+            # state group -> joined hosts, which is much more space efficient
+            # than event ID -> joined hosts.
+            #
+            # Note: We have to cache event ID -> prev state group, as we don't
+            # store that in the DB.
+            #
+            # Note: We set the state group -> joined hosts cache if it hasn't been
+            # set for a while, so that the expiry time is reset.
 
-        if state_entry.state_group:
-            await self._external_cache.set(
-                "event_to_prev_state_group",
-                event.event_id,
-                state_entry.state_group,
-                expiry_ms=60 * 60 * 1000,
+            state_entry = await self.state.resolve_state_groups_for_events(
+                event.room_id, event_ids=event.prev_event_ids()
             )
 
-            if state_entry.state_group in self._external_cache_joined_hosts_updates:
-                return
+            if state_entry.state_group:
+                await self._external_cache.set(
+                    "event_to_prev_state_group",
+                    event.event_id,
+                    state_entry.state_group,
+                    expiry_ms=60 * 60 * 1000,
+                )
 
-            state = await state_entry.get_state(
-                self._storage_controllers.state, StateFilter.all()
-            )
-            with opentracing.start_active_span("get_joined_hosts"):
-                joined_hosts = await self.store.get_joined_hosts(
-                    event.room_id, state, state_entry
+                if state_entry.state_group in self._external_cache_joined_hosts_updates:
+                    return
+
+                state = await state_entry.get_state(
+                    self._storage_controllers.state, StateFilter.all()
                 )
+                with opentracing.start_active_span("get_joined_hosts"):
+                    joined_hosts = await self.store.get_joined_hosts(
+                        event.room_id, state, state_entry
+                    )
 
-            # Note that the expiry times must be larger than the expiry time in
-            # _external_cache_joined_hosts_updates.
-            await self._external_cache.set(
-                "get_joined_hosts",
-                str(state_entry.state_group),
-                list(joined_hosts),
-                expiry_ms=60 * 60 * 1000,
-            )
+                # Note that the expiry times must be larger than the expiry time in
+                # _external_cache_joined_hosts_updates.
+                await self._external_cache.set(
+                    "get_joined_hosts",
+                    str(state_entry.state_group),
+                    list(joined_hosts),
+                    expiry_ms=60 * 60 * 1000,
+                )
 
-            self._external_cache_joined_hosts_updates[state_entry.state_group] = None
+                self._external_cache_joined_hosts_updates[
+                    state_entry.state_group
+                ] = None
 
     async def _validate_canonical_alias(
         self,