diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 788500e38f..b223dc750b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -139,6 +139,55 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+ def test_persisting_event_prefills_get_event_cache(self) -> None:
+ """
+ Test to make sure that the `_get_event_cache` is prefilled after we persist an
+ event and returns the updated value.
+ """
+ event, event_context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=self.user,
+ type="test_event_type",
+ content={"body": "conflabulation"},
+ )
+ )
+
+ # First, check `_get_event_cache` for the event we just made
+ # to verify it's not in the cache.
+ res = self.store._get_event_cache.get_local((event.event_id,))
+ self.assertEqual(res, None, "Event was cached when it should not have been.")
+
+ with LoggingContext(name="test") as ctx:
+ # Persist the event which should invalidate then prefill the
+ # `_get_event_cache` so we don't return stale values.
+ # Side Note: Apparently, persisting an event isn't a transaction in the
+ # sense that it is recorded in the LoggingContext
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.get_success(
+ persistence.persist_event(
+ event,
+ event_context,
+ )
+ )
+
+ # Check `_get_event_cache` again and we should see the updated fact
+ # that we now have the event cached after persisting it.
+ res = self.store._get_event_cache.get_local((event.event_id,))
+ self.assertEqual(res.event, event, "Event not cached as expected.") # type: ignore
+
+ # Try and fetch the event from the database.
+ self.get_success(self.store.get_event(event.event_id))
+
+ # Verify that the database hit was avoided.
+ self.assertEqual(
+ ctx.get_resource_usage().evt_db_fetch_count,
+ 0,
+ "Database was hit, which would not happen if event was cached.",
+ )
+
def test_invalidate_cache_by_room_id(self) -> None:
"""
Test to make sure that all events associated with the given `(room_id,)`
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e39b63edac..48ebfadaab 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
assert persist_events_store is not None
persist_events_store._store_event_txn(
txn,
- [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
+ [
+ (e, EventContext(self.hs.get_storage_controllers(), {}))
+ for e in events
+ ],
)
# Actually call the function that calculates the auth chain stuff.
diff --git a/tests/storage/test_event_federation.py b/tests/storage/test_event_federation.py
index 4b8d8328d7..0f3b0744f1 100644
--- a/tests/storage/test_event_federation.py
+++ b/tests/storage/test_event_federation.py
@@ -20,7 +20,6 @@ from parameterized import parameterized
from twisted.test.proto_helpers import MemoryReactor
-from synapse.api.constants import EventTypes
from synapse.api.room_versions import (
KNOWN_ROOM_VERSIONS,
EventFormatVersions,
@@ -924,216 +923,6 @@ class EventFederationWorkerStoreTestCase(tests.unittest.HomeserverTestCase):
backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
self.assertEqual(backfill_event_ids, ["b3", "b2", "b1"])
- def _setup_room_for_insertion_backfill_tests(self) -> _BackfillSetupInfo:
- """
- Sets up a room with various insertion event backward extremities to test
- backfill functions against.
-
- Returns:
- _BackfillSetupInfo including the `room_id` to test against and
- `depth_map` of events in the room
- """
- room_id = "!backfill-room-test:some-host"
-
- depth_map: Dict[str, int] = {
- "1": 1,
- "2": 2,
- "insertion_eventA": 3,
- "3": 4,
- "insertion_eventB": 5,
- "4": 6,
- "5": 7,
- }
-
- def populate_db(txn: LoggingTransaction) -> None:
- # Insert the room to satisfy the foreign key constraint of
- # `event_failed_pull_attempts`
- self.store.db_pool.simple_insert_txn(
- txn,
- "rooms",
- {
- "room_id": room_id,
- "creator": "room_creator_user_id",
- "is_public": True,
- "room_version": "6",
- },
- )
-
- # Insert our server events
- stream_ordering = 0
- for event_id, depth in depth_map.items():
- self.store.db_pool.simple_insert_txn(
- txn,
- table="events",
- values={
- "event_id": event_id,
- "type": EventTypes.MSC2716_INSERTION
- if event_id.startswith("insertion_event")
- else "test_regular_type",
- "room_id": room_id,
- "depth": depth,
- "topological_ordering": depth,
- "stream_ordering": stream_ordering,
- "processed": True,
- "outlier": False,
- },
- )
-
- if event_id.startswith("insertion_event"):
- self.store.db_pool.simple_insert_txn(
- txn,
- table="insertion_event_extremities",
- values={
- "event_id": event_id,
- "room_id": room_id,
- },
- )
-
- stream_ordering += 1
-
- self.get_success(
- self.store.db_pool.runInteraction(
- "_setup_room_for_insertion_backfill_tests_populate_db",
- populate_db,
- )
- )
-
- return _BackfillSetupInfo(room_id=room_id, depth_map=depth_map)
-
- def test_get_insertion_event_backward_extremities_in_room(self) -> None:
- """
- Test to make sure only insertion event backward extremities that are
- older and come before the `current_depth` are returned.
- """
- setup_info = self._setup_room_for_insertion_backfill_tests()
- room_id = setup_info.room_id
- depth_map = setup_info.depth_map
-
- # Try at "insertion_eventB"
- backfill_points = self.get_success(
- self.store.get_insertion_event_backward_extremities_in_room(
- room_id, depth_map["insertion_eventB"], limit=100
- )
- )
- backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
- self.assertEqual(backfill_event_ids, ["insertion_eventB", "insertion_eventA"])
-
- # Try at "insertion_eventA"
- backfill_points = self.get_success(
- self.store.get_insertion_event_backward_extremities_in_room(
- room_id, depth_map["insertion_eventA"], limit=100
- )
- )
- backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
- # Event "2" has a depth of 2 but is not included here because we only
- # know the approximate depth of 5 from our event "3".
- self.assertListEqual(backfill_event_ids, ["insertion_eventA"])
-
- def test_get_insertion_event_backward_extremities_in_room_excludes_events_we_have_attempted(
- self,
- ) -> None:
- """
- Test to make sure that insertion events we have attempted to backfill
- (and within backoff timeout duration) do not show up as an event to
- backfill again.
- """
- setup_info = self._setup_room_for_insertion_backfill_tests()
- room_id = setup_info.room_id
- depth_map = setup_info.depth_map
-
- # Record some attempts to backfill these events which will make
- # `get_insertion_event_backward_extremities_in_room` exclude them
- # because we haven't passed the backoff interval.
- self.get_success(
- self.store.record_event_failed_pull_attempt(
- room_id, "insertion_eventA", "fake cause"
- )
- )
-
- # No time has passed since we attempted to backfill ^
-
- # Try at "insertion_eventB"
- backfill_points = self.get_success(
- self.store.get_insertion_event_backward_extremities_in_room(
- room_id, depth_map["insertion_eventB"], limit=100
- )
- )
- backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
- # Only the backfill points that we didn't record earlier exist here.
- self.assertEqual(backfill_event_ids, ["insertion_eventB"])
-
- def test_get_insertion_event_backward_extremities_in_room_attempted_event_retry_after_backoff_duration(
- self,
- ) -> None:
- """
- Test to make sure after we fake attempt to backfill event
- "insertion_eventA" many times, we can see retry and see the
- "insertion_eventA" again after the backoff timeout duration has
- exceeded.
- """
- setup_info = self._setup_room_for_insertion_backfill_tests()
- room_id = setup_info.room_id
- depth_map = setup_info.depth_map
-
- # Record some attempts to backfill these events which will make
- # `get_backfill_points_in_room` exclude them because we
- # haven't passed the backoff interval.
- self.get_success(
- self.store.record_event_failed_pull_attempt(
- room_id, "insertion_eventB", "fake cause"
- )
- )
- self.get_success(
- self.store.record_event_failed_pull_attempt(
- room_id, "insertion_eventA", "fake cause"
- )
- )
- self.get_success(
- self.store.record_event_failed_pull_attempt(
- room_id, "insertion_eventA", "fake cause"
- )
- )
- self.get_success(
- self.store.record_event_failed_pull_attempt(
- room_id, "insertion_eventA", "fake cause"
- )
- )
- self.get_success(
- self.store.record_event_failed_pull_attempt(
- room_id, "insertion_eventA", "fake cause"
- )
- )
-
- # Now advance time by 2 hours and we should only be able to see
- # "insertion_eventB" because we have waited long enough for the single
- # attempt (2^1 hours) but we still shouldn't see "insertion_eventA"
- # because we haven't waited long enough for this many attempts.
- self.reactor.advance(datetime.timedelta(hours=2).total_seconds())
-
- # Try at "insertion_eventA" and make sure that "insertion_eventA" is not
- # in the list because we've already attempted many times
- backfill_points = self.get_success(
- self.store.get_insertion_event_backward_extremities_in_room(
- room_id, depth_map["insertion_eventA"], limit=100
- )
- )
- backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
- self.assertEqual(backfill_event_ids, [])
-
- # Now advance time by 20 hours (above 2^4 because we made 4 attemps) and
- # see if we can now backfill it
- self.reactor.advance(datetime.timedelta(hours=20).total_seconds())
-
- # Try at "insertion_eventA" again after we advanced enough time and we
- # should see "insertion_eventA" again
- backfill_points = self.get_success(
- self.store.get_insertion_event_backward_extremities_in_room(
- room_id, depth_map["insertion_eventA"], limit=100
- )
- )
- backfill_event_ids = [backfill_point[0] for backfill_point in backfill_points]
- self.assertEqual(backfill_event_ids, ["insertion_eventA"])
-
def test_get_event_ids_with_failed_pull_attempts(self) -> None:
"""
Test to make sure we properly get event_ids based on whether they have any
|