summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/15233.misc1
-rw-r--r--changelog.d/15737.feature1
-rw-r--r--changelog.d/15755.misc1
-rw-r--r--changelog.d/15758.bugfix1
-rw-r--r--changelog.d/15770.bugfix1
-rw-r--r--changelog.d/15772.doc1
-rw-r--r--synapse/events/snapshot.py159
-rw-r--r--synapse/federation/federation_client.py5
-rw-r--r--synapse/federation/federation_server.py4
-rw-r--r--synapse/handlers/pagination.py137
-rw-r--r--synapse/storage/controllers/persist_events.py5
-rw-r--r--synapse/storage/databases/main/events.py15
-rw-r--r--synapse/storage/databases/main/events_worker.py2
-rw-r--r--synapse/util/__init__.py5
-rw-r--r--synapse/util/caches/lrucache.py8
-rw-r--r--tests/events/test_snapshot.py3
-rw-r--r--tests/storage/databases/main/test_events_worker.py49
-rw-r--r--tests/storage/test_event_chain.py5
-rw-r--r--tests/test_state.py11
19 files changed, 327 insertions, 87 deletions
diff --git a/changelog.d/15233.misc b/changelog.d/15233.misc
new file mode 100644
index 0000000000..1dff00bf3c
--- /dev/null
+++ b/changelog.d/15233.misc
@@ -0,0 +1 @@
+Replace `EventContext` fields `prev_group` and `delta_ids` with field `state_group_deltas`.
diff --git a/changelog.d/15737.feature b/changelog.d/15737.feature
new file mode 100644
index 0000000000..9a547b5ebd
--- /dev/null
+++ b/changelog.d/15737.feature
@@ -0,0 +1 @@
+Improve `/messages` response time by avoiding backfill when we already have messages to return.
diff --git a/changelog.d/15755.misc b/changelog.d/15755.misc
new file mode 100644
index 0000000000..a65340d380
--- /dev/null
+++ b/changelog.d/15755.misc
@@ -0,0 +1 @@
+Fix requesting multiple keys at once over federation, related to [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983).
diff --git a/changelog.d/15758.bugfix b/changelog.d/15758.bugfix
new file mode 100644
index 0000000000..cabe25ca24
--- /dev/null
+++ b/changelog.d/15758.bugfix
@@ -0,0 +1 @@
+Avoid invalidating a cache that was just prefilled.
diff --git a/changelog.d/15770.bugfix b/changelog.d/15770.bugfix
new file mode 100644
index 0000000000..a65340d380
--- /dev/null
+++ b/changelog.d/15770.bugfix
@@ -0,0 +1 @@
+Fix requesting multiple keys at once over federation, related to [MSC3983](https://github.com/matrix-org/matrix-spec-proposals/pull/3983).
diff --git a/changelog.d/15772.doc b/changelog.d/15772.doc
new file mode 100644
index 0000000000..4d6c933c71
--- /dev/null
+++ b/changelog.d/15772.doc
@@ -0,0 +1 @@
+Document `looping_call()` functionality that will wait for the given function to finish before scheduling another.
diff --git a/synapse/events/snapshot.py b/synapse/events/snapshot.py
index e7e8225b8e..a43498ed4d 100644
--- a/synapse/events/snapshot.py
+++ b/synapse/events/snapshot.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from abc import ABC, abstractmethod
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Dict, List, Optional, Tuple
 
 import attr
 from immutabledict import immutabledict
@@ -107,33 +107,32 @@ class EventContext(UnpersistedEventContextBase):
         state_delta_due_to_event: If `state_group` and `state_group_before_event` are not None
             then this is the delta of the state between the two groups.
 
-        prev_group: If it is known, ``state_group``'s prev_group. Note that this being
-            None does not necessarily mean that ``state_group`` does not have
-            a prev_group!
+        state_group_deltas: If not empty, this is a dict collecting a mapping of the state
+            difference between state groups.
 
-            If the event is a state event, this is normally the same as
-            ``state_group_before_event``.
+            The keys are a tuple of two integers: the initial group and final state group.
+            The corresponding value is a state map representing the state delta between
+            these state groups.
 
-            If ``state_group`` is None (ie, the event is an outlier), ``prev_group``
-            will always also be ``None``.
+            The dictionary is expected to have at most two entries with state groups of:
 
-            Note that this *not* (necessarily) the state group associated with
-            ``_prev_state_ids``.
+            1. The state group before the event and after the event.
+            2. The state group preceding the state group before the event and the
+               state group before the event.
 
-        delta_ids: If ``prev_group`` is not None, the state delta between ``prev_group``
-            and ``state_group``.
+            This information is collected and stored as part of an optimization for persisting
+            events.
 
         partial_state: if True, we may be storing this event with a temporary,
             incomplete state.
     """
 
     _storage: "StorageControllers"
+    state_group_deltas: Dict[Tuple[int, int], StateMap[str]]
     rejected: Optional[str] = None
     _state_group: Optional[int] = None
     state_group_before_event: Optional[int] = None
     _state_delta_due_to_event: Optional[StateMap[str]] = None
-    prev_group: Optional[int] = None
-    delta_ids: Optional[StateMap[str]] = None
     app_service: Optional[ApplicationService] = None
 
     partial_state: bool = False
@@ -145,16 +144,14 @@ class EventContext(UnpersistedEventContextBase):
         state_group_before_event: Optional[int],
         state_delta_due_to_event: Optional[StateMap[str]],
         partial_state: bool,
-        prev_group: Optional[int] = None,
-        delta_ids: Optional[StateMap[str]] = None,
+        state_group_deltas: Dict[Tuple[int, int], StateMap[str]],
     ) -> "EventContext":
         return EventContext(
             storage=storage,
             state_group=state_group,
             state_group_before_event=state_group_before_event,
             state_delta_due_to_event=state_delta_due_to_event,
-            prev_group=prev_group,
-            delta_ids=delta_ids,
+            state_group_deltas=state_group_deltas,
             partial_state=partial_state,
         )
 
@@ -163,7 +160,7 @@ class EventContext(UnpersistedEventContextBase):
         storage: "StorageControllers",
     ) -> "EventContext":
         """Return an EventContext instance suitable for persisting an outlier event"""
-        return EventContext(storage=storage)
+        return EventContext(storage=storage, state_group_deltas={})
 
     async def persist(self, event: EventBase) -> "EventContext":
         return self
@@ -183,13 +180,15 @@ class EventContext(UnpersistedEventContextBase):
             "state_group": self._state_group,
             "state_group_before_event": self.state_group_before_event,
             "rejected": self.rejected,
-            "prev_group": self.prev_group,
+            "state_group_deltas": _encode_state_group_delta(self.state_group_deltas),
             "state_delta_due_to_event": _encode_state_dict(
                 self._state_delta_due_to_event
             ),
-            "delta_ids": _encode_state_dict(self.delta_ids),
             "app_service_id": self.app_service.id if self.app_service else None,
             "partial_state": self.partial_state,
+            # add dummy delta_ids and prev_group for backwards compatibility
+            "delta_ids": None,
+            "prev_group": None,
         }
 
     @staticmethod
@@ -204,17 +203,24 @@ class EventContext(UnpersistedEventContextBase):
         Returns:
             The event context.
         """
+        # workaround for backwards/forwards compatibility: if the input doesn't have a value
+        # for "state_group_deltas" just assign an empty dict
+        state_group_deltas = input.get("state_group_deltas", None)
+        if state_group_deltas:
+            state_group_deltas = _decode_state_group_delta(state_group_deltas)
+        else:
+            state_group_deltas = {}
+
         context = EventContext(
             # We use the state_group and prev_state_id stuff to pull the
             # current_state_ids out of the DB and construct prev_state_ids.
             storage=storage,
             state_group=input["state_group"],
             state_group_before_event=input["state_group_before_event"],
-            prev_group=input["prev_group"],
+            state_group_deltas=state_group_deltas,
             state_delta_due_to_event=_decode_state_dict(
                 input["state_delta_due_to_event"]
             ),
-            delta_ids=_decode_state_dict(input["delta_ids"]),
             rejected=input["rejected"],
             partial_state=input.get("partial_state", False),
         )
@@ -349,7 +355,7 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
     _storage: "StorageControllers"
     state_group_before_event: Optional[int]
     state_group_after_event: Optional[int]
-    state_delta_due_to_event: Optional[dict]
+    state_delta_due_to_event: Optional[StateMap[str]]
     prev_group_for_state_group_before_event: Optional[int]
     delta_ids_to_state_group_before_event: Optional[StateMap[str]]
     partial_state: bool
@@ -380,26 +386,16 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
 
         events_and_persisted_context = []
         for event, unpersisted_context in amended_events_and_context:
-            if event.is_state():
-                context = EventContext(
-                    storage=unpersisted_context._storage,
-                    state_group=unpersisted_context.state_group_after_event,
-                    state_group_before_event=unpersisted_context.state_group_before_event,
-                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
-                    partial_state=unpersisted_context.partial_state,
-                    prev_group=unpersisted_context.state_group_before_event,
-                    delta_ids=unpersisted_context.state_delta_due_to_event,
-                )
-            else:
-                context = EventContext(
-                    storage=unpersisted_context._storage,
-                    state_group=unpersisted_context.state_group_after_event,
-                    state_group_before_event=unpersisted_context.state_group_before_event,
-                    state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
-                    partial_state=unpersisted_context.partial_state,
-                    prev_group=unpersisted_context.prev_group_for_state_group_before_event,
-                    delta_ids=unpersisted_context.delta_ids_to_state_group_before_event,
-                )
+            state_group_deltas = unpersisted_context._build_state_group_deltas()
+
+            context = EventContext(
+                storage=unpersisted_context._storage,
+                state_group=unpersisted_context.state_group_after_event,
+                state_group_before_event=unpersisted_context.state_group_before_event,
+                state_delta_due_to_event=unpersisted_context.state_delta_due_to_event,
+                partial_state=unpersisted_context.partial_state,
+                state_group_deltas=state_group_deltas,
+            )
             events_and_persisted_context.append((event, context))
         return events_and_persisted_context
 
@@ -452,11 +448,11 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
 
         # if the event isn't a state event the state group doesn't change
         if not self.state_delta_due_to_event:
-            state_group_after_event = self.state_group_before_event
+            self.state_group_after_event = self.state_group_before_event
 
         # otherwise if it is a state event we need to get a state group for it
         else:
-            state_group_after_event = await self._storage.state.store_state_group(
+            self.state_group_after_event = await self._storage.state.store_state_group(
                 event.event_id,
                 event.room_id,
                 prev_group=self.state_group_before_event,
@@ -464,16 +460,81 @@ class UnpersistedEventContext(UnpersistedEventContextBase):
                 current_state_ids=None,
             )
 
+        state_group_deltas = self._build_state_group_deltas()
+
         return EventContext.with_state(
             storage=self._storage,
-            state_group=state_group_after_event,
+            state_group=self.state_group_after_event,
             state_group_before_event=self.state_group_before_event,
             state_delta_due_to_event=self.state_delta_due_to_event,
+            state_group_deltas=state_group_deltas,
             partial_state=self.partial_state,
-            prev_group=self.state_group_before_event,
-            delta_ids=self.state_delta_due_to_event,
         )
 
+    def _build_state_group_deltas(self) -> Dict[Tuple[int, int], StateMap]:
+        """
+        Collect deltas between the state groups associated with this context
+        """
+        state_group_deltas = {}
+
+        # if we know the state group before the event and after the event, add them and the
+        # state delta between them to state_group_deltas
+        if self.state_group_before_event and self.state_group_after_event:
+            # if we have the state groups we should have the delta
+            assert self.state_delta_due_to_event is not None
+            state_group_deltas[
+                (
+                    self.state_group_before_event,
+                    self.state_group_after_event,
+                )
+            ] = self.state_delta_due_to_event
+
+        # the state group before the event may also have a state group which precedes it, if
+        # we have that and the state group before the event, add them and the state
+        # delta between them to state_group_deltas
+        if (
+            self.prev_group_for_state_group_before_event
+            and self.state_group_before_event
+        ):
+            # if we have both state groups we should have the delta between them
+            assert self.delta_ids_to_state_group_before_event is not None
+            state_group_deltas[
+                (
+                    self.prev_group_for_state_group_before_event,
+                    self.state_group_before_event,
+                )
+            ] = self.delta_ids_to_state_group_before_event
+
+        return state_group_deltas
+
+
+def _encode_state_group_delta(
+    state_group_delta: Dict[Tuple[int, int], StateMap[str]]
+) -> List[Tuple[int, int, Optional[List[Tuple[str, str, str]]]]]:
+    if not state_group_delta:
+        return []
+
+    state_group_delta_encoded = []
+    for key, value in state_group_delta.items():
+        state_group_delta_encoded.append((key[0], key[1], _encode_state_dict(value)))
+
+    return state_group_delta_encoded
+
+
+def _decode_state_group_delta(
+    input: List[Tuple[int, int, List[Tuple[str, str, str]]]]
+) -> Dict[Tuple[int, int], StateMap[str]]:
+    if not input:
+        return {}
+
+    state_group_deltas = {}
+    for state_group_1, state_group_2, state_dict in input:
+        state_map = _decode_state_dict(state_dict)
+        assert state_map is not None
+        state_group_deltas[(state_group_1, state_group_2)] = state_map
+
+    return state_group_deltas
+
 
 def _encode_state_dict(
     state_dict: Optional[StateMap[str]],
diff --git a/synapse/federation/federation_client.py b/synapse/federation/federation_client.py
index a2cf3a96c6..e5359ca558 100644
--- a/synapse/federation/federation_client.py
+++ b/synapse/federation/federation_client.py
@@ -260,7 +260,9 @@ class FederationClient(FederationBase):
         use_unstable = False
         for user_id, one_time_keys in query.items():
             for device_id, algorithms in one_time_keys.items():
-                if any(count > 1 for count in algorithms.values()):
+                # If more than one algorithm is requested, attempt to use the unstable
+                # endpoint.
+                if sum(algorithms.values()) > 1:
                     use_unstable = True
                 if algorithms:
                     # For the stable query, choose only the first algorithm.
@@ -296,6 +298,7 @@ class FederationClient(FederationBase):
         else:
             logger.debug("Skipping unstable claim client keys API")
 
+        # TODO Potentially attempt multiple queries and combine the results?
         return await self.transport_layer.claim_client_keys(
             user, destination, content, timeout
         )
diff --git a/synapse/federation/federation_server.py b/synapse/federation/federation_server.py
index 9425b32507..61fa3b30af 100644
--- a/synapse/federation/federation_server.py
+++ b/synapse/federation/federation_server.py
@@ -1016,7 +1016,9 @@ class FederationServer(FederationBase):
             for user_id, device_keys in result.items():
                 for device_id, keys in device_keys.items():
                     for key_id, key in keys.items():
-                        json_result.setdefault(user_id, {})[device_id] = {key_id: key}
+                        json_result.setdefault(user_id, {}).setdefault(device_id, {})[
+                            key_id
+                        ] = key
 
         logger.info(
             "Claimed one-time-keys: %s",
diff --git a/synapse/handlers/pagination.py b/synapse/handlers/pagination.py
index d5257acb7d..19b8728db9 100644
--- a/synapse/handlers/pagination.py
+++ b/synapse/handlers/pagination.py
@@ -40,6 +40,11 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+# How many single event gaps we tolerate returning in a `/messages` response before we
+# backfill and try to fill in the history. This is an arbitrarily picked number so feel
+# free to tune it in the future.
+BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD = 3
+
 
 @attr.s(slots=True, auto_attribs=True)
 class PurgeStatus:
@@ -486,35 +491,35 @@ class PaginationHandler:
                         room_id, room_token.stream
                     )
 
-                if not use_admin_priviledge and membership == Membership.LEAVE:
-                    # If they have left the room then clamp the token to be before
-                    # they left the room, to save the effort of loading from the
-                    # database.
-
-                    # This is only None if the room is world_readable, in which
-                    # case "JOIN" would have been returned.
-                    assert member_event_id
+            # If they have left the room then clamp the token to be before
+            # they left the room, to save the effort of loading from the
+            # database.
+            if (
+                pagin_config.direction == Direction.BACKWARDS
+                and not use_admin_priviledge
+                and membership == Membership.LEAVE
+            ):
+                # This is only None if the room is world_readable, in which case
+                # "Membership.JOIN" would have been returned and we should never hit
+                # this branch.
+                assert member_event_id
+
+                leave_token = await self.store.get_topological_token_for_event(
+                    member_event_id
+                )
+                assert leave_token.topological is not None
 
-                    leave_token = await self.store.get_topological_token_for_event(
-                        member_event_id
+                if leave_token.topological < curr_topo:
+                    from_token = from_token.copy_and_replace(
+                        StreamKeyType.ROOM, leave_token
                     )
-                    assert leave_token.topological is not None
-
-                    if leave_token.topological < curr_topo:
-                        from_token = from_token.copy_and_replace(
-                            StreamKeyType.ROOM, leave_token
-                        )
-
-                await self.hs.get_federation_handler().maybe_backfill(
-                    room_id,
-                    curr_topo,
-                    limit=pagin_config.limit,
-                )
 
             to_room_key = None
             if pagin_config.to_token:
                 to_room_key = pagin_config.to_token.room_key
 
+            # Initially fetch the events from the database. With any luck, we can return
+            # these without blocking on backfill (handled below).
             events, next_key = await self.store.paginate_room_events(
                 room_id=room_id,
                 from_key=from_token.room_key,
@@ -524,6 +529,94 @@ class PaginationHandler:
                 event_filter=event_filter,
             )
 
+            if pagin_config.direction == Direction.BACKWARDS:
+                # We use a `Set` because there can be multiple events at a given depth
+                # and we only care about looking at the unique continum of depths to
+                # find gaps.
+                event_depths: Set[int] = {event.depth for event in events}
+                sorted_event_depths = sorted(event_depths)
+
+                # Inspect the depths of the returned events to see if there are any gaps
+                found_big_gap = False
+                number_of_gaps = 0
+                previous_event_depth = (
+                    sorted_event_depths[0] if len(sorted_event_depths) > 0 else 0
+                )
+                for event_depth in sorted_event_depths:
+                    # We don't expect a negative depth but we'll just deal with it in
+                    # any case by taking the absolute value to get the true gap between
+                    # any two integers.
+                    depth_gap = abs(event_depth - previous_event_depth)
+                    # A `depth_gap` of 1 is a normal continuous chain to the next event
+                    # (1 <-- 2 <-- 3) so anything larger indicates a missing event (it's
+                    # also possible there is no event at a given depth but we can't ever
+                    # know that for sure)
+                    if depth_gap > 1:
+                        number_of_gaps += 1
+
+                    # We only tolerate a small number single-event long gaps in the
+                    # returned events because those are most likely just events we've
+                    # failed to pull in the past. Anything longer than that is probably
+                    # a sign that we're missing a decent chunk of history and we should
+                    # try to backfill it.
+                    #
+                    # XXX: It's possible we could tolerate longer gaps if we checked
+                    # that a given events `prev_events` is one that has failed pull
+                    # attempts and we could just treat it like a dead branch of history
+                    # for now or at least something that we don't need the block the
+                    # client on to try pulling.
+                    #
+                    # XXX: If we had something like MSC3871 to indicate gaps in the
+                    # timeline to the client, we could also get away with any sized gap
+                    # and just have the client refetch the holes as they see fit.
+                    if depth_gap > 2:
+                        found_big_gap = True
+                        break
+                    previous_event_depth = event_depth
+
+                # Backfill in the foreground if we found a big gap, have too many holes,
+                # or we don't have enough events to fill the limit that the client asked
+                # for.
+                missing_too_many_events = (
+                    number_of_gaps > BACKFILL_BECAUSE_TOO_MANY_GAPS_THRESHOLD
+                )
+                not_enough_events_to_fill_response = len(events) < pagin_config.limit
+                if (
+                    found_big_gap
+                    or missing_too_many_events
+                    or not_enough_events_to_fill_response
+                ):
+                    did_backfill = (
+                        await self.hs.get_federation_handler().maybe_backfill(
+                            room_id,
+                            curr_topo,
+                            limit=pagin_config.limit,
+                        )
+                    )
+
+                    # If we did backfill something, refetch the events from the database to
+                    # catch anything new that might have been added since we last fetched.
+                    if did_backfill:
+                        events, next_key = await self.store.paginate_room_events(
+                            room_id=room_id,
+                            from_key=from_token.room_key,
+                            to_key=to_room_key,
+                            direction=pagin_config.direction,
+                            limit=pagin_config.limit,
+                            event_filter=event_filter,
+                        )
+                else:
+                    # Otherwise, we can backfill in the background for eventual
+                    # consistency's sake but we don't need to block the client waiting
+                    # for a costly federation call and processing.
+                    run_as_background_process(
+                        "maybe_backfill_in_the_background",
+                        self.hs.get_federation_handler().maybe_backfill,
+                        room_id,
+                        curr_topo,
+                        limit=pagin_config.limit,
+                    )
+
             next_token = from_token.copy_and_replace(StreamKeyType.ROOM, next_key)
 
         # if no events are returned from pagination, that implies
diff --git a/synapse/storage/controllers/persist_events.py b/synapse/storage/controllers/persist_events.py
index f1d2c71c91..35c0680365 100644
--- a/synapse/storage/controllers/persist_events.py
+++ b/synapse/storage/controllers/persist_events.py
@@ -839,9 +839,8 @@ class EventsPersistenceStorageController:
                         "group" % (ev.event_id,)
                     )
                 continue
-
-            if ctx.prev_group:
-                state_group_deltas[(ctx.prev_group, ctx.state_group)] = ctx.delta_ids
+            if ctx.state_group_deltas:
+                state_group_deltas.update(ctx.state_group_deltas)
 
         # We need to map the event_ids to their state groups. First, let's
         # check if the event is one we're persisting, in which case we can
diff --git a/synapse/storage/databases/main/events.py b/synapse/storage/databases/main/events.py
index e2e6eb479f..44af3357af 100644
--- a/synapse/storage/databases/main/events.py
+++ b/synapse/storage/databases/main/events.py
@@ -1729,13 +1729,22 @@ class PersistEventsStore:
             if not row["rejects"] and not row["redacts"]:
                 to_prefill.append(EventCacheEntry(event=event, redacted_event=None))
 
-        async def prefill() -> None:
+        async def external_prefill() -> None:
             for cache_entry in to_prefill:
-                await self.store._get_event_cache.set(
+                await self.store._get_event_cache.set_external(
                     (cache_entry.event.event_id,), cache_entry
                 )
 
-        txn.async_call_after(prefill)
+        def local_prefill() -> None:
+            for cache_entry in to_prefill:
+                self.store._get_event_cache.set_local(
+                    (cache_entry.event.event_id,), cache_entry
+                )
+
+        # The order these are called here is not as important as knowing that after the
+        # transaction is finished, the async_call_after will run before the call_after.
+        txn.async_call_after(external_prefill)
+        txn.call_after(local_prefill)
 
     def _store_redaction(self, txn: LoggingTransaction, event: EventBase) -> None:
         assert event.redacts is not None
diff --git a/synapse/storage/databases/main/events_worker.py b/synapse/storage/databases/main/events_worker.py
index d93ffc4efa..7e7648c951 100644
--- a/synapse/storage/databases/main/events_worker.py
+++ b/synapse/storage/databases/main/events_worker.py
@@ -883,7 +883,7 @@ class EventsWorkerStore(SQLBaseStore):
 
     async def _invalidate_async_get_event_cache(self, event_id: str) -> None:
         """
-        Invalidates an event in the asyncronous get event cache, which may be remote.
+        Invalidates an event in the asynchronous get event cache, which may be remote.
 
         Arguments:
             event_id: the event ID to invalidate
diff --git a/synapse/util/__init__.py b/synapse/util/__init__.py
index 7ea0c4c36b..9f3b8741c1 100644
--- a/synapse/util/__init__.py
+++ b/synapse/util/__init__.py
@@ -116,6 +116,11 @@ class Clock:
 
         Waits `msec` initially before calling `f` for the first time.
 
+        If the function given to `looping_call` returns an awaitable/deferred, the next
+        call isn't scheduled until after the returned awaitable has finished. We get
+        this functionality thanks to this function being a thin wrapper around
+        `twisted.internet.task.LoopingCall`.
+
         Note that the function will be called with no logcontext, so if it is anything
         other than trivial, you probably want to wrap it in run_as_background_process.
 
diff --git a/synapse/util/caches/lrucache.py b/synapse/util/caches/lrucache.py
index 6137c85e10..be6554319a 100644
--- a/synapse/util/caches/lrucache.py
+++ b/synapse/util/caches/lrucache.py
@@ -842,7 +842,13 @@ class AsyncLruCache(Generic[KT, VT]):
         return self._lru_cache.get(key, update_metrics=update_metrics)
 
     async def set(self, key: KT, value: VT) -> None:
-        self._lru_cache.set(key, value)
+        # This will add the entries in the correct order, local first external second
+        self.set_local(key, value)
+        await self.set_external(key, value)
+
+    async def set_external(self, key: KT, value: VT) -> None:
+        # This method should add an entry to any configured external cache, in this case noop.
+        pass
 
     def set_local(self, key: KT, value: VT) -> None:
         self._lru_cache.set(key, value)
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 6687c28e8f..b5e42f9600 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase):
         self.assertEqual(
             context.state_group_before_event, d_context.state_group_before_event
         )
-        self.assertEqual(context.prev_group, d_context.prev_group)
-        self.assertEqual(context.delta_ids, d_context.delta_ids)
+        self.assertEqual(context.state_group_deltas, d_context.state_group_deltas)
         self.assertEqual(context.app_service, d_context.app_service)
 
         self.assertEqual(
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 788500e38f..b223dc750b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -139,6 +139,55 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
             # That should result in a single db query to lookup
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
 
+    def test_persisting_event_prefills_get_event_cache(self) -> None:
+        """
+        Test to make sure that the `_get_event_cache` is prefilled after we persist an
+        event and returns the updated value.
+        """
+        event, event_context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=self.user,
+                type="test_event_type",
+                content={"body": "conflabulation"},
+            )
+        )
+
+        # First, check `_get_event_cache` for the event we just made
+        # to verify it's not in the cache.
+        res = self.store._get_event_cache.get_local((event.event_id,))
+        self.assertEqual(res, None, "Event was cached when it should not have been.")
+
+        with LoggingContext(name="test") as ctx:
+            # Persist the event which should invalidate then prefill the
+            # `_get_event_cache` so we don't return stale values.
+            # Side Note: Apparently, persisting an event isn't a transaction in the
+            # sense that it is recorded in the LoggingContext
+            persistence = self.hs.get_storage_controllers().persistence
+            assert persistence is not None
+            self.get_success(
+                persistence.persist_event(
+                    event,
+                    event_context,
+                )
+            )
+
+            # Check `_get_event_cache` again and we should see the updated fact
+            # that we now have the event cached after persisting it.
+            res = self.store._get_event_cache.get_local((event.event_id,))
+            self.assertEqual(res.event, event, "Event not cached as expected.")  # type: ignore
+
+            # Try and fetch the event from the database.
+            self.get_success(self.store.get_event(event.event_id))
+
+            # Verify that the database hit was avoided.
+            self.assertEqual(
+                ctx.get_resource_usage().evt_db_fetch_count,
+                0,
+                "Database was hit, which would not happen if event was cached.",
+            )
+
     def test_invalidate_cache_by_room_id(self) -> None:
         """
         Test to make sure that all events associated with the given `(room_id,)`
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e39b63edac..48ebfadaab 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
             assert persist_events_store is not None
             persist_events_store._store_event_txn(
                 txn,
-                [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
+                [
+                    (e, EventContext(self.hs.get_storage_controllers(), {}))
+                    for e in events
+                ],
             )
 
             # Actually call the function that calculates the auth chain stuff.
diff --git a/tests/test_state.py b/tests/test_state.py
index 7a49b87953..eded38c766 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase):
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
 
-        self.assertIsNotNone(context.state_group_before_event)
+        assert context.state_group_before_event is not None
+        assert context.state_group is not None
+        self.assertEqual(
+            context.state_group_deltas.get(
+                (context.state_group_before_event, context.state_group)
+            ),
+            {(event.type, event.state_key): event.event_id},
+        )
         self.assertNotEqual(context.state_group_before_event, context.state_group)
-        self.assertEqual(context.state_group_before_event, context.prev_group)
-        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(