summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--changelog.d/8751.misc1
-rw-r--r--changelog.d/8758.misc1
-rw-r--r--synapse/handlers/room_member.py36
-rw-r--r--tests/rest/key/v2/test_remote_key_resource.py6
-rw-r--r--tests/server.py42
5 files changed, 44 insertions, 42 deletions
diff --git a/changelog.d/8751.misc b/changelog.d/8751.misc
new file mode 100644
index 0000000000..204c280c0e
--- /dev/null
+++ b/changelog.d/8751.misc
@@ -0,0 +1 @@
+Generalise `RoomMemberHandler._locally_reject_invite` to apply to more flows than just invite.
\ No newline at end of file
diff --git a/changelog.d/8758.misc b/changelog.d/8758.misc
new file mode 100644
index 0000000000..54502e9b90
--- /dev/null
+++ b/changelog.d/8758.misc
@@ -0,0 +1 @@
+Refactor test utilities for injecting HTTP requests.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7cd858b7db..fd85e08973 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1104,32 +1104,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             #
             logger.warning("Failed to reject invite: %s", e)
 
-            return await self._locally_reject_invite(
+            return await self._generate_local_out_of_band_leave(
                 invite_event, txn_id, requester, content
             )
 
-    async def _locally_reject_invite(
+    async def _generate_local_out_of_band_leave(
         self,
-        invite_event: EventBase,
+        previous_membership_event: EventBase,
         txn_id: Optional[str],
         requester: Requester,
         content: JsonDict,
     ) -> Tuple[str, int]:
-        """Generate a local invite rejection
+        """Generate a local leave event for a room
 
-        This is called after we fail to reject an invite via a remote server. It
-        generates an out-of-band membership event locally.
+        This can be called after we e.g fail to reject an invite via a remote server.
+        It generates an out-of-band membership event locally.
 
         Args:
-            invite_event: the invite to be rejected
+            previous_membership_event: the previous membership event for this user
             txn_id: optional transaction ID supplied by the client
-            requester:  user making the rejection request, according to the access token
-            content: additional content to include in the rejection event.
+            requester: user making the request, according to the access token
+            content: additional content to include in the leave event.
                Normally an empty dict.
-        """
 
-        room_id = invite_event.room_id
-        target_user = invite_event.state_key
+        Returns:
+            A tuple containing (event_id, stream_id of the leave event)
+        """
+        room_id = previous_membership_event.room_id
+        target_user = previous_membership_event.state_key
 
         content["membership"] = Membership.LEAVE
 
@@ -1141,12 +1143,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
             "state_key": target_user,
         }
 
-        # the auth events for the new event are the same as that of the invite, plus
-        # the invite itself.
+        # the auth events for the new event are the same as that of the previous event, plus
+        # the event itself.
         #
-        # the prev_events are just the invite.
-        prev_event_ids = [invite_event.event_id]
-        auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
+        # the prev_events consist solely of the previous membership event.
+        prev_event_ids = [previous_membership_event.event_id]
+        auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
 
         event, context = await self.event_creation_handler.create_event(
             requester,
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6671cbd32d..fbcf8d5b86 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -32,7 +32,7 @@ from synapse.util.httpresourcetree import create_resource_tree
 from synapse.util.stringutils import random_string
 
 from tests import unittest
-from tests.server import FakeChannel, wait_until_result
+from tests.server import FakeChannel
 from tests.utils import default_config
 
 
@@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
             % (server_name.encode("utf-8"), key_id.encode("utf-8")),
             b"1.1",
         )
-        wait_until_result(self.reactor, req)
+        channel.await_result()
         self.assertEqual(channel.code, 200)
         resp = channel.json_body
         return resp
@@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
             req.requestReceived(
                 b"POST", path.encode("utf-8"), b"1.1",
             )
-            wait_until_result(self.reactor, req)
+            channel.await_result()
             self.assertEqual(channel.code, 200)
             resp = channel.json_body
             return resp
diff --git a/tests/server.py b/tests/server.py
index 5850eadf3e..5a1583a3e7 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -118,6 +118,25 @@ class FakeChannel:
     def transport(self):
         return self
 
+    def await_result(self, timeout: int = 100) -> None:
+        """
+        Wait until the request is finished.
+        """
+        self._reactor.run()
+        x = 0
+
+        while not self.result.get("done"):
+            # If there's a producer, tell it to resume producing so we get content
+            if self._producer:
+                self._producer.resumeProducing()
+
+            x += 1
+
+            if x > timeout:
+                raise TimedOutException("Timed out waiting for request to finish.")
+
+            self._reactor.advance(0.1)
+
 
 class FakeSite:
     """
@@ -241,30 +260,9 @@ def make_request(
     return req, channel
 
 
-def wait_until_result(clock, request, timeout=100):
-    """
-    Wait until the request is finished.
-    """
-    clock.run()
-    x = 0
-
-    while not request.finished:
-
-        # If there's a producer, tell it to resume producing so we get content
-        if request._channel._producer:
-            request._channel._producer.resumeProducing()
-
-        x += 1
-
-        if x > timeout:
-            raise TimedOutException("Timed out waiting for request to finish.")
-
-        clock.advance(0.1)
-
-
 def render(request, resource, clock):
     request.render(resource)
-    wait_until_result(clock, request)
+    request._channel.await_result()
 
 
 @implementer(IReactorPluggableNameResolver)