summary refs log tree commit diff
path: root/tests/unittest.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/unittest.py')
-rw-r--r--tests/unittest.py15
1 files changed, 10 insertions, 5 deletions
diff --git a/tests/unittest.py b/tests/unittest.py
index 2d73911747..4b31f84494 100644
--- a/tests/unittest.py
+++ b/tests/unittest.py
@@ -45,7 +45,7 @@ from typing_extensions import Concatenate, ParamSpec, Protocol
 from twisted.internet.defer import Deferred, ensureDeferred
 from twisted.python.failure import Failure
 from twisted.python.threadpool import ThreadPool
-from twisted.test.proto_helpers import MemoryReactor
+from twisted.test.proto_helpers import MemoryReactor, MemoryReactorClock
 from twisted.trial import unittest
 from twisted.web.resource import Resource
 from twisted.web.server import Request
@@ -82,7 +82,7 @@ from tests.server import (
 )
 from tests.test_utils import event_injection, setup_awaitable_errors
 from tests.test_utils.logging_setup import setup_logging
-from tests.utils import default_config, setupdb
+from tests.utils import checked_cast, default_config, setupdb
 
 setupdb()
 setup_logging()
@@ -296,7 +296,12 @@ class HomeserverTestCase(TestCase):
 
         from tests.rest.client.utils import RestHelper
 
-        self.helper = RestHelper(self.hs, self.site, getattr(self, "user_id", None))
+        self.helper = RestHelper(
+            self.hs,
+            checked_cast(MemoryReactorClock, self.hs.get_reactor()),
+            self.site,
+            getattr(self, "user_id", None),
+        )
 
         if hasattr(self, "user_id"):
             if self.hijack_auth:
@@ -718,7 +723,7 @@ class HomeserverTestCase(TestCase):
         event_creator = self.hs.get_event_creation_handler()
         requester = create_requester(user)
 
-        event, context, _ = self.get_success(
+        event, unpersisted_context, _ = self.get_success(
             event_creator.create_event(
                 requester,
                 {
@@ -730,7 +735,7 @@ class HomeserverTestCase(TestCase):
                 prev_event_ids=prev_event_ids,
             )
         )
-
+        context = self.get_success(unpersisted_context.persist(event))
         if soft_failed:
             event.internal_metadata.soft_failed = True