summary refs log tree commit diff
path: root/tests/storage
diff options
context:
space:
mode:
Diffstat (limited to 'tests/storage')
-rw-r--r--tests/storage/test_redaction.py24
-rw-r--r--tests/storage/test_state.py4
2 files changed, 21 insertions, 7 deletions
diff --git a/tests/storage/test_redaction.py b/tests/storage/test_redaction.py
index df4740f9d9..0100f7da14 100644
--- a/tests/storage/test_redaction.py
+++ b/tests/storage/test_redaction.py
@@ -74,10 +74,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -96,10 +98,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -119,10 +123,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         self.get_success(self._persistence.persist_event(event, context))
 
         return event
@@ -259,7 +265,7 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             def internal_metadata(self) -> _EventInternalMetadata:
                 return self._base_builder.internal_metadata
 
-        event_1, context_1 = self.get_success(
+        event_1, unpersisted_context_1 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 cast(
                     EventBuilder,
@@ -280,9 +286,11 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             )
         )
 
+        context_1 = self.get_success(unpersisted_context_1.persist(event_1))
+
         self.get_success(self._persistence.persist_event(event_1, context_1))
 
-        event_2, context_2 = self.get_success(
+        event_2, unpersisted_context_2 = self.get_success(
             self.event_creation_handler.create_new_client_event(
                 cast(
                     EventBuilder,
@@ -302,6 +310,8 @@ class RedactionTestCase(unittest.HomeserverTestCase):
                 )
             )
         )
+
+        context_2 = self.get_success(unpersisted_context_2.persist(event_2))
         self.get_success(self._persistence.persist_event(event_2, context_2))
 
         # fetch one of the redactions
@@ -421,10 +431,12 @@ class RedactionTestCase(unittest.HomeserverTestCase):
             },
         )
 
-        redaction_event, context = self.get_success(
+        redaction_event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(redaction_event))
+
         self.get_success(self._persistence.persist_event(redaction_event, context))
 
         # Now lets jump to the future where we have censored the redaction event
diff --git a/tests/storage/test_state.py b/tests/storage/test_state.py
index bad7f0bc60..f730b888f7 100644
--- a/tests/storage/test_state.py
+++ b/tests/storage/test_state.py
@@ -67,10 +67,12 @@ class StateStoreTestCase(HomeserverTestCase):
             },
         )
 
-        event, context = self.get_success(
+        event, unpersisted_context = self.get_success(
             self.event_creation_handler.create_new_client_event(builder)
         )
 
+        context = self.get_success(unpersisted_context.persist(event))
+
         assert self.storage.persistence is not None
         self.get_success(self.storage.persistence.persist_event(event, context))