diff --git a/changelog.d/8751.misc b/changelog.d/8751.misc
new file mode 100644
index 0000000000..204c280c0e
--- /dev/null
+++ b/changelog.d/8751.misc
@@ -0,0 +1 @@
+Generalise `RoomMemberHandler._locally_reject_invite` to apply to more flows than just invite.
\ No newline at end of file
diff --git a/changelog.d/8758.misc b/changelog.d/8758.misc
new file mode 100644
index 0000000000..54502e9b90
--- /dev/null
+++ b/changelog.d/8758.misc
@@ -0,0 +1 @@
+Refactor test utilities for injecting HTTP requests.
diff --git a/synapse/handlers/room_member.py b/synapse/handlers/room_member.py
index 7cd858b7db..fd85e08973 100644
--- a/synapse/handlers/room_member.py
+++ b/synapse/handlers/room_member.py
@@ -1104,32 +1104,34 @@ class RoomMemberMasterHandler(RoomMemberHandler):
#
logger.warning("Failed to reject invite: %s", e)
- return await self._locally_reject_invite(
+ return await self._generate_local_out_of_band_leave(
invite_event, txn_id, requester, content
)
- async def _locally_reject_invite(
+ async def _generate_local_out_of_band_leave(
self,
- invite_event: EventBase,
+ previous_membership_event: EventBase,
txn_id: Optional[str],
requester: Requester,
content: JsonDict,
) -> Tuple[str, int]:
- """Generate a local invite rejection
+ """Generate a local leave event for a room
- This is called after we fail to reject an invite via a remote server. It
- generates an out-of-band membership event locally.
+ This can be called after we e.g fail to reject an invite via a remote server.
+ It generates an out-of-band membership event locally.
Args:
- invite_event: the invite to be rejected
+ previous_membership_event: the previous membership event for this user
txn_id: optional transaction ID supplied by the client
- requester: user making the rejection request, according to the access token
- content: additional content to include in the rejection event.
+ requester: user making the request, according to the access token
+ content: additional content to include in the leave event.
Normally an empty dict.
- """
- room_id = invite_event.room_id
- target_user = invite_event.state_key
+ Returns:
+ A tuple containing (event_id, stream_id of the leave event)
+ """
+ room_id = previous_membership_event.room_id
+ target_user = previous_membership_event.state_key
content["membership"] = Membership.LEAVE
@@ -1141,12 +1143,12 @@ class RoomMemberMasterHandler(RoomMemberHandler):
"state_key": target_user,
}
- # the auth events for the new event are the same as that of the invite, plus
- # the invite itself.
+ # the auth events for the new event are the same as that of the previous event, plus
+ # the event itself.
#
- # the prev_events are just the invite.
- prev_event_ids = [invite_event.event_id]
- auth_event_ids = invite_event.auth_event_ids() + prev_event_ids
+ # the prev_events consist solely of the previous membership event.
+ prev_event_ids = [previous_membership_event.event_id]
+ auth_event_ids = previous_membership_event.auth_event_ids() + prev_event_ids
event, context = await self.event_creation_handler.create_event(
requester,
diff --git a/tests/rest/key/v2/test_remote_key_resource.py b/tests/rest/key/v2/test_remote_key_resource.py
index 6671cbd32d..fbcf8d5b86 100644
--- a/tests/rest/key/v2/test_remote_key_resource.py
+++ b/tests/rest/key/v2/test_remote_key_resource.py
@@ -32,7 +32,7 @@ from synapse.util.httpresourcetree import create_resource_tree
from synapse.util.stringutils import random_string
from tests import unittest
-from tests.server import FakeChannel, wait_until_result
+from tests.server import FakeChannel
from tests.utils import default_config
@@ -94,7 +94,7 @@ class RemoteKeyResourceTestCase(BaseRemoteKeyResourceTestCase):
% (server_name.encode("utf-8"), key_id.encode("utf-8")),
b"1.1",
)
- wait_until_result(self.reactor, req)
+ channel.await_result()
self.assertEqual(channel.code, 200)
resp = channel.json_body
return resp
@@ -190,7 +190,7 @@ class EndToEndPerspectivesTests(BaseRemoteKeyResourceTestCase):
req.requestReceived(
b"POST", path.encode("utf-8"), b"1.1",
)
- wait_until_result(self.reactor, req)
+ channel.await_result()
self.assertEqual(channel.code, 200)
resp = channel.json_body
return resp
diff --git a/tests/server.py b/tests/server.py
index 5850eadf3e..5a1583a3e7 100644
--- a/tests/server.py
+++ b/tests/server.py
@@ -118,6 +118,25 @@ class FakeChannel:
def transport(self):
return self
+ def await_result(self, timeout: int = 100) -> None:
+ """
+ Wait until the request is finished.
+ """
+ self._reactor.run()
+ x = 0
+
+ while not self.result.get("done"):
+ # If there's a producer, tell it to resume producing so we get content
+ if self._producer:
+ self._producer.resumeProducing()
+
+ x += 1
+
+ if x > timeout:
+ raise TimedOutException("Timed out waiting for request to finish.")
+
+ self._reactor.advance(0.1)
+
class FakeSite:
"""
@@ -241,30 +260,9 @@ def make_request(
return req, channel
-def wait_until_result(clock, request, timeout=100):
- """
- Wait until the request is finished.
- """
- clock.run()
- x = 0
-
- while not request.finished:
-
- # If there's a producer, tell it to resume producing so we get content
- if request._channel._producer:
- request._channel._producer.resumeProducing()
-
- x += 1
-
- if x > timeout:
- raise TimedOutException("Timed out waiting for request to finish.")
-
- clock.advance(0.1)
-
-
def render(request, resource, clock):
request.render(resource)
- wait_until_result(clock, request)
+ request._channel.await_result()
@implementer(IReactorPluggableNameResolver)
|