summary refs log tree commit diff
path: root/tests
diff options
context:
space:
mode:
Diffstat (limited to 'tests')
-rw-r--r--tests/events/test_snapshot.py3
-rw-r--r--tests/storage/databases/main/test_events_worker.py49
-rw-r--r--tests/storage/test_event_chain.py5
-rw-r--r--tests/test_state.py11
4 files changed, 62 insertions, 6 deletions
diff --git a/tests/events/test_snapshot.py b/tests/events/test_snapshot.py
index 6687c28e8f..b5e42f9600 100644
--- a/tests/events/test_snapshot.py
+++ b/tests/events/test_snapshot.py
@@ -101,8 +101,7 @@ class TestEventContext(unittest.HomeserverTestCase):
         self.assertEqual(
             context.state_group_before_event, d_context.state_group_before_event
         )
-        self.assertEqual(context.prev_group, d_context.prev_group)
-        self.assertEqual(context.delta_ids, d_context.delta_ids)
+        self.assertEqual(context.state_group_deltas, d_context.state_group_deltas)
         self.assertEqual(context.app_service, d_context.app_service)
 
         self.assertEqual(
diff --git a/tests/storage/databases/main/test_events_worker.py b/tests/storage/databases/main/test_events_worker.py
index 788500e38f..b223dc750b 100644
--- a/tests/storage/databases/main/test_events_worker.py
+++ b/tests/storage/databases/main/test_events_worker.py
@@ -139,6 +139,55 @@ class HaveSeenEventsTestCase(unittest.HomeserverTestCase):
             # That should result in a single db query to lookup
             self.assertEqual(ctx.get_resource_usage().db_txn_count, 1)
 
+    def test_persisting_event_prefills_get_event_cache(self) -> None:
+        """
+        Test to make sure that the `_get_event_cache` is prefilled after we persist an
+        event and returns the updated value.
+        """
+        event, event_context = self.get_success(
+            create_event(
+                self.hs,
+                room_id=self.room_id,
+                sender=self.user,
+                type="test_event_type",
+                content={"body": "conflabulation"},
+            )
+        )
+
+        # First, check `_get_event_cache` for the event we just made
+        # to verify it's not in the cache.
+        res = self.store._get_event_cache.get_local((event.event_id,))
+        self.assertEqual(res, None, "Event was cached when it should not have been.")
+
+        with LoggingContext(name="test") as ctx:
+            # Persist the event which should invalidate then prefill the
+            # `_get_event_cache` so we don't return stale values.
+            # Side Note: Apparently, persisting an event isn't a transaction in the
+            # sense that it is recorded in the LoggingContext
+            persistence = self.hs.get_storage_controllers().persistence
+            assert persistence is not None
+            self.get_success(
+                persistence.persist_event(
+                    event,
+                    event_context,
+                )
+            )
+
+            # Check `_get_event_cache` again and we should see the updated fact
+            # that we now have the event cached after persisting it.
+            res = self.store._get_event_cache.get_local((event.event_id,))
+            self.assertEqual(res.event, event, "Event not cached as expected.")  # type: ignore
+
+            # Try and fetch the event from the database.
+            self.get_success(self.store.get_event(event.event_id))
+
+            # Verify that the database hit was avoided.
+            self.assertEqual(
+                ctx.get_resource_usage().evt_db_fetch_count,
+                0,
+                "Database was hit, which would not happen if event was cached.",
+            )
+
     def test_invalidate_cache_by_room_id(self) -> None:
         """
         Test to make sure that all events associated with the given `(room_id,)`
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index e39b63edac..48ebfadaab 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -401,7 +401,10 @@ class EventChainStoreTestCase(HomeserverTestCase):
             assert persist_events_store is not None
             persist_events_store._store_event_txn(
                 txn,
-                [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
+                [
+                    (e, EventContext(self.hs.get_storage_controllers(), {}))
+                    for e in events
+                ],
             )
 
             # Actually call the function that calculates the auth chain stuff.
diff --git a/tests/test_state.py b/tests/test_state.py
index 7a49b87953..eded38c766 100644
--- a/tests/test_state.py
+++ b/tests/test_state.py
@@ -555,10 +555,15 @@ class StateTestCase(unittest.TestCase):
             (e.event_id for e in old_state + [event]), current_state_ids.values()
         )
 
-        self.assertIsNotNone(context.state_group_before_event)
+        assert context.state_group_before_event is not None
+        assert context.state_group is not None
+        self.assertEqual(
+            context.state_group_deltas.get(
+                (context.state_group_before_event, context.state_group)
+            ),
+            {(event.type, event.state_key): event.event_id},
+        )
         self.assertNotEqual(context.state_group_before_event, context.state_group)
-        self.assertEqual(context.state_group_before_event, context.prev_group)
-        self.assertEqual({("state", ""): event.event_id}, context.delta_ids)
 
     @defer.inlineCallbacks
     def test_trivial_annotate_message(