summary refs log tree commit diff
path: root/tests/handlers/test_message.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/handlers/test_message.py')
-rw-r--r--tests/handlers/test_message.py36
1 files changed, 23 insertions, 13 deletions
diff --git a/tests/handlers/test_message.py b/tests/handlers/test_message.py
index c4727ab917..9691d66b48 100644
--- a/tests/handlers/test_message.py
+++ b/tests/handlers/test_message.py
@@ -18,7 +18,7 @@ from twisted.test.proto_helpers import MemoryReactor
 
 from synapse.api.constants import EventTypes
 from synapse.events import EventBase
-from synapse.events.snapshot import EventContext
+from synapse.events.snapshot import EventContext, UnpersistedEventContextBase
 from synapse.rest import admin
 from synapse.rest.client import login, room
 from synapse.server import HomeServer
@@ -41,20 +41,21 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
     def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
         self.handler = self.hs.get_event_creation_handler()
-        self._persist_event_storage_controller = (
-            self.hs.get_storage_controllers().persistence
-        )
+        persistence = self.hs.get_storage_controllers().persistence
+        assert persistence is not None
+        self._persist_event_storage_controller = persistence
 
         self.user_id = self.register_user("tester", "foobar")
         self.access_token = self.login("tester", "foobar")
         self.room_id = self.helper.create_room_as(self.user_id, tok=self.access_token)
 
-        self.info = self.get_success(
+        info = self.get_success(
             self.hs.get_datastores().main.get_user_by_access_token(
                 self.access_token,
             )
         )
-        self.token_id = self.info.token_id
+        assert info is not None
+        self.token_id = info.token_id
 
         self.requester = create_requester(self.user_id, access_token_id=self.token_id)
 
@@ -78,7 +79,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         return memberEvent, memberEventContext
 
-    def _create_duplicate_event(self, txn_id: str) -> Tuple[EventBase, EventContext]:
+    def _create_duplicate_event(
+        self, txn_id: str
+    ) -> Tuple[EventBase, UnpersistedEventContextBase]:
         """Create a new event with the given transaction ID. All events produced
         by this method will be considered duplicates.
         """
@@ -106,7 +109,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         txn_id = "something_suitably_random"
 
-        event1, context = self._create_duplicate_event(txn_id)
+        event1, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event1))
 
         ret_event1 = self.get_success(
             self.handler.handle_new_client_event(
@@ -118,7 +122,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         self.assertEqual(event1.event_id, ret_event1.event_id)
 
-        event2, context = self._create_duplicate_event(txn_id)
+        event2, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event2))
 
         # We want to test that the deduplication at the persit event end works,
         # so we want to make sure we test with different events.
@@ -139,7 +144,9 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         # Let's test that calling `persist_event` directly also does the right
         # thing.
-        event3, context = self._create_duplicate_event(txn_id)
+        event3, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event3))
+
         self.assertNotEqual(event1.event_id, event3.event_id)
 
         ret_event3, event_pos3, _ = self.get_success(
@@ -153,7 +160,8 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
 
         # Let's test that calling `persist_events` directly also does the right
         # thing.
-        event4, context = self._create_duplicate_event(txn_id)
+        event4, unpersisted_context = self._create_duplicate_event(txn_id)
+        context = self.get_success(unpersisted_context.persist(event4))
         self.assertNotEqual(event1.event_id, event3.event_id)
 
         events, _ = self.get_success(
@@ -173,8 +181,10 @@ class EventCreationTestCase(unittest.HomeserverTestCase):
         txn_id = "something_else_suitably_random"
 
         # Create two duplicate events to persist at the same time
-        event1, context1 = self._create_duplicate_event(txn_id)
-        event2, context2 = self._create_duplicate_event(txn_id)
+        event1, unpersisted_context1 = self._create_duplicate_event(txn_id)
+        context1 = self.get_success(unpersisted_context1.persist(event1))
+        event2, unpersisted_context2 = self._create_duplicate_event(txn_id)
+        context2 = self.get_success(unpersisted_context2.persist(event2))
 
         # Ensure their event IDs are different to start with
         self.assertNotEqual(event1.event_id, event2.event_id)