summary refs log tree commit diff
path: root/tests/storage/test_event_chain.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage/test_event_chain.py')
-rw-r--r--tests/storage/test_event_chain.py10
1 files changed, 8 insertions, 2 deletions
diff --git a/tests/storage/test_event_chain.py b/tests/storage/test_event_chain.py
index c070278db8..a10e5fa8b1 100644
--- a/tests/storage/test_event_chain.py
+++ b/tests/storage/test_event_chain.py
@@ -389,6 +389,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
         """
 
         persist_events_store = self.hs.get_datastores().persist_events
+        assert persist_events_store is not None
 
         for e in events:
             e.internal_metadata.stream_ordering = self._next_stream_ordering
@@ -397,6 +398,7 @@ class EventChainStoreTestCase(HomeserverTestCase):
         def _persist(txn: LoggingTransaction) -> None:
             # We need to persist the events to the events and state_events
             # tables.
+            assert persist_events_store is not None
             persist_events_store._store_event_txn(
                 txn,
                 [(e, EventContext(self.hs.get_storage_controllers())) for e in events],
@@ -540,7 +542,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 self.requester, events_and_context=[(event, context)]
             )
         )
-        state1 = set(self.get_success(context.get_current_state_ids()).values())
+        state_ids1 = self.get_success(context.get_current_state_ids())
+        assert state_ids1 is not None
+        state1 = set(state_ids1.values())
 
         event, context = self.get_success(
             event_handler.create_event(
@@ -560,7 +564,9 @@ class EventChainBackgroundUpdateTestCase(HomeserverTestCase):
                 self.requester, events_and_context=[(event, context)]
             )
         )
-        state2 = set(self.get_success(context.get_current_state_ids()).values())
+        state_ids2 = self.get_success(context.get_current_state_ids())
+        assert state_ids2 is not None
+        state2 = set(state_ids2.values())
 
         # Delete the chain cover info.