diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 6687c28e8f..b5e42f9600 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase):
self.assertEqual(
context.state_group_before_event, d_context.state_group_before_event
)
- self.assertEqual(context.prev_group, d_context.prev_group)
- self.assertEqual(context.delta_ids, d_context.delta_ids)
+ self.assertEqual(context.state_group_deltas, d_context.state_group_deltas)
self.assertEqual(context.app_service, d_context.app_service)
self.assertEqual(
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 788500e38f..b223dc750b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -139,6 +139,55 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
# That should result in a single db query to lookup
self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
+ def test_persisting_event_prefills_get_event_cache(self) -> None:
+ """
+ Test to make sure that the `_get_event_cache` is prefilled after we persist an
+ event and returns the updated value.
+ """
+ event, event_context = self.get_success(
+ create_event(
+ self.hs,
+ room_id=self.room_id,
+ sender=self.user,
+ type="test_event_type",
+ content={"body": "conflabulation"},
+ )
+ )
+
+ # First, check `_get_event_cache` for the event we just made
+ # to verify it's not in the cache.
+ res = self.store._get_event_cache.get_local((event.event_id,))
+ self.assertEqual(res, None, "Event was cached when it should not have been.")
+
+ with LoggingContext(name="test") as ctx:
+ # Persist the event which should invalidate then prefill the
+ # `_get_event_cache` so we don't return stale values.
+ # Side Note: Apparently, persisting an event isn't a transaction in the
+ # sense that it is recorded in the LoggingContext
+ persistence = self.hs.get_storage_controllers().persistence
+ assert persistence is not None
+ self.get_success(
+ persistence.persist_event(
+ event,
+ event_context,
+ )
+ )
+
+ # Check `_get_event_cache` again and we should see the updated fact
+ # that we now have the event cached after persisting it.
+ res = self.store._get_event_cache.get_local((event.event_id,))
+ self.assertEqual(res.event, event, "Event not cached as expected.") # type: ignore
+
+ # Try and fetch the event from the database.
+ self.get_success(self.store.get_event(event.event_id))
+
+ # Verify that the database hit was avoided.
+ self.assertEqual(
+ ctx.get_resource_usage().evt_db_fetch_count,
+ 0,
+ "Database was hit, which would not happen if event was cached.",
+ )
+
def test_invalidate_cache_by_room_id(self) -> None:
"""
Test to make sure that all events associated with the given `(room_id,)`
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e39b63edac..48ebfadaab 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
assert persist_events_store is not None
persist_events_store._store_event_txn(
txn,
- [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
+ [
+ (e, EventContext(self.hs.get_storage_controllers(), {}))
+ for e in events
+ ],
)
# Actually call the function that calculates the auth chain stuff.
diff --git a/tests/test_state.py b/tests/test_state.py
index 7a49b87953..eded38c766 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase):
(e.event_id for e in old_state + [event]), current_state_ids.values()
)
- self.assertIsNotNone(context.state_group_before_event)
+ assert context.state_group_before_event is not None
+ assert context.state_group is not None
+ self.assertEqual(
+ context.state_group_deltas.get(
+ (context.state_group_before_event, context.state_group)
+ ),
+ {(event.type, event.state_key): event.event_id},
+ )
self.assertNotEqual(context.state_group_before_event, context.state_group)
- self.assertEqual(context.state_group_before_event, context.prev_group)
- self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
@defer.inlineCallbacks
def test_trivial_annotate_message(
|