summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8758.misc1
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py6
-rw-r--r--tests/server.py42
3 files changed, 24 insertions, 25 deletions
diff --git a/changelog.d/8758.misc b/changelog.d/8758.misc
new file mode 100644
index 0000000000..54502e9b90
--- /dev/null
+++ b/changelog.d/8758.misc
@@ -0,0 +1 @@
+Refactor test utilities for injecting HTTP requests.
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6671cbd32d..fbcf8d5b86 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -32,7 +32,7 @@ from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.server import FakeChannel, wait_until_result
+from tests.server import FakeChannel
 from tests.utils import default_config
 
 
@@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
             % (server_name.encode("utf-8"), key_id.encode("utf-8")),
             b"1.1",
         )
-        wait_until_result(self.reactor, req)
+        channel.await_result()
         self.assertEqual(channel.code, 200)
         resp = channel.json_body
         return resp
@@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             req.requestReceived(
                 b"POST", path.encode("utf-8"), b"1.1",
             )
-            wait_until_result(self.reactor, req)
+            channel.await_result()
             self.assertEqual(channel.code, 200)
             resp = channel.json_body
             return resp
diff --git a/tests/server.py b/tests/server.py
index ef03109a6c..18cb8b2d72 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -117,6 +117,25 @@ class FakeChannel:
     def transport(self):
         return self
 
+    def await_result(self, timeout: int = 100) -> None:
+        """
+        Wait until the request is finished.
+        """
+        self._reactor.run()
+        x = 0
+
+        while not self.result.get("done"):
+            # If there's a producer, tell it to resume producing so we get content
+            if self._producer:
+                self._producer.resumeProducing()
+
+            x += 1
+
+            if x > timeout:
+                raise TimedOutException("Timed out waiting for request to finish.")
+
+            self._reactor.advance(0.1)
+
 
 class FakeSite:
     """
@@ -225,30 +244,9 @@ def make_request(
     return req, channel
 
 
-def wait_until_result(clock, request, timeout=100):
-    """
-    Wait until the request is finished.
-    """
-    clock.run()
-    x = 0
-
-    while not request.finished:
-
-        # If there's a producer, tell it to resume producing so we get content
-        if request._channel._producer:
-            request._channel._producer.resumeProducing()
-
-        x += 1
-
-        if x > timeout:
-            raise TimedOutException("Timed out waiting for request to finish.")
-
-        clock.advance(0.1)
-
-
 def render(request, resource, clock):
     request.render(resource)
-    wait_until_result(clock, request)
+    request._channel.await_result()
 
 
 @implementer(IReactorPluggableNameResolver)